In [1]:
import numpy as np
import pandas as pd
import pickle
import gc
import os
import pysam
import torch
from torch.utils.data import DataLoader, Dataset
from itertools import chain

In [2]:
from encoding_utils import sequence_encoders
import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions
from models.spec_dss import DSSResNetEmb, SpecAdd

# Train/Test Model

**0. Specify input parameters**

In [3]:
input_params = misc.dotdict({})

input_params.dataset =  '/vol/storage/ouologuems/other/systems_genetics/datasets/phase3_top10/dataset.parquet'
input_params.model_weight = '/vol/storage/ouologuems/other/systems_genetics/checkpoints/phase3_top10/aware_large_splitmsk/weights/epoch_100_weights_model.pt'
#input_params.species_list = datadir + 'fasta/240_species/240_species.txt'

input_params.output_dir = '/vol/storage/ouologuems/other/systems_genetics/test'

input_params.split_mask = False
input_params.mask_rate = 0.2 #[0.012,0.2]#RAN #single float or 2 floats for reference and alternative
input_params.masking = 'none' # stratified_maf or none

input_params.test = True

input_params.get_embeddings = True
input_params.mask_at_test = True

input_params.agnostic = False

input_params.seq_len = 5000

input_params.tot_epochs = 1
input_params.fold = 0
input_params.Nfolds = 5

input_params.train_splits = 1

input_params.save_at = [1]
input_params.validate_every = 1

input_params.d_model = 256
input_params.n_layers = 16
input_params.dropout = 0.

input_params.batch_size = 1
input_params.learning_rate = 1e-4
input_params.weight_decay = 0

if input_params.dataset.endswith('.fa'):
    seq_df = pd.read_csv(input_params.dataset + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
elif input_params.dataset.endswith('.parquet'):
    seq_df = pd.read_parquet(input_params.dataset).reset_index()
    
seq_df[['split','sample_id','seg_name']] =  seq_df['seq_name'].str.split(':',expand=True)

if not input_params.agnostic:
    #for segment-aware model, assign a label to each segment
    seg_name = seq_df.seq_name.apply(lambda x:':'.join(x.split(':')[2:]))
    segment_encoding = seg_name.drop_duplicates().reset_index(drop=True)
    segment_encoding = {seg_name:idx for idx,seg_name in segment_encoding.items()}
    seq_df['seg_label'] = seg_name.map(segment_encoding)
else:
    seq_df['seg_label'] = 0


if input_params.test:
    seq_df = seq_df[seq_df.split=='test']
else:
    seq_df = seq_df[seq_df.split!='test']

In [4]:
seq_df.head()

Unnamed: 0,seq_name,seq,split,sample_id,seg_name,seg_label
20010,test:NA20795:ENSG00000198502.5,FBFFFFFRMRMMMFBFFMBRBRFFRFFFFFFFFMFFBFFFFFFFFB...,test,NA20795,ENSG00000198502.5,4
20011,test:HG00260:ENSG00000214425.1,RRBRRRBRRRBBRRBBBBRBBBRRBRBRRBRRRRBBBRBRRBBRRB...,test,HG00260,ENSG00000214425.1,1
20012,test:HG01632:ENSG00000176681.9,BBRBBBRBBBBBBRBBBBRBBBBBBBBBBBBBBBBBRRBBBBRBBB...,test,HG01632,ENSG00000176681.9,9
20013,test:HG00173:ENSG00000238083.3,RMRRRRRRRRRRRRRRRRFRRRRFRRRRRRMRRRRMRRRRRRRMRR...,test,HG00173,ENSG00000238083.3,0
20014,test:HG00178:ENSG00000229450.2,RRRRRRRRRRRRRRRRRRRRRRRRRRRRRRBBRRBRRRRRRRRRRR...,test,HG00178,ENSG00000229450.2,3


In [5]:
seq_df.shape

(5030, 6)

In [6]:
seq_df.split.unique()

array(['test'], dtype=object)

In [7]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('\nCUDA device: GPU\n')
else:
    device = torch.device('cpu')
    print('\nCUDA device: CPU\n')
gc.collect()
torch.cuda.empty_cache()
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"


CUDA device: GPU



In [8]:
device = "cpu"

**1. Dataset and Dataloader**

Define Dataset: 

In [9]:
class SeqDataset(Dataset):

    def __init__(self, seq_df, transform, max_augm_shift=0, 
                 mode='train'):

        if input_params.dataset.endswith('.fa'):
            self.fasta = pysam.FastaFile(input_params.dataset)
        else:
            self.fasta = None

        self.seq_df = seq_df
        self.transform = transform
        self.max_augm_shift = max_augm_shift
        self.mode = mode

    def __len__(self):

        return 2*len(self.seq_df) # times two because returns both haplotypes 

    def __getitem__(self, idx):

        if self.fasta:
            seq = self.fasta.fetch(self.seq_df.iloc[idx].seq_name).upper()
        else:
            seq = self.seq_df.iloc[idx].seq.upper()

        shift = np.random.randint(self.max_augm_shift+1) #random shift at training, must be chunk_size-input_params.seq_len

        seq = seq[shift:shift+input_params.seq_len] #shift the sequence and limit its size
        seg_label = self.seq_df.iloc[idx].seg_label #label for segment-aware training
        
        seq1 = seq.replace('-','').replace('B','A').replace('F','A').replace('M','R') # father
        seq2 = seq.replace('-','').replace('B','A').replace('M','A').replace('F','R') # mother 

        masked_sequence1, target_labels_masked1, target_labels1, _, _ = self.transform(seq1)
        masked_sequence2, target_labels_masked2, target_labels2, _, _ = self.transform(seq2)

        masked_sequence = torch.vstack((masked_sequence1, masked_sequence2))
        seg_label = torch.vstack((torch.tensor(seg_label), torch.tensor(seg_label)))
        masked_sequence = (masked_sequence, seg_label)

        target_labels_masked = torch.vstack((target_labels_masked1, target_labels_masked2))
        target_labels = torch.vstack((target_labels1, target_labels2))
        seq = (seq1, seq2)
        return masked_sequence, target_labels_masked, target_labels, seq
        
        '''
        #for given genotype, randomly choose a haplotype for training/testing
        if np.random.rand()>0.5:
            seq = seq.replace('-','').replace('B','A').replace('F','A').replace('M','R')
        else:
            seq = seq.replace('-','').replace('B','A').replace('M','A').replace('F','R')

        #if input_params.masking == 'stratified_maf' and not input_params.test:
        #    #select mask for the sequence depending on sequence coordinates w.r.t. contig
        #    seg_name = self.seq_df.iloc[idx].seg_name
        #    seq_mask = meta.loc[seg_name].MASK.values
        #    masked_sequence, target_labels_masked, target_labels, _, _ = self.transform(seq, mask = seq_mask)
        #else:
        #    masked_sequence, target_labels_masked, target_labels, _, _ = self.transform(seq)

        masked_sequence, target_labels_masked, target_labels, _, _ = self.transform(seq)

        masked_sequence = (masked_sequence, seg_label)
        return masked_sequence, target_labels_masked, target_labels, seq
        '''

    def close(self):
        self.fasta.close()

In [10]:
def collate_fn(data): 
    #masked sequence
    masked_sequence = [x[0][0] for x in data]
    masked_sequence = [torch.stack(torch.split(d, 3)) for d in masked_sequence] 
    masked_sequence = torch.concat(masked_sequence)
    #seg labels
    seg_labels = [x[0][1] for x in data]
    seg_labels = torch.concat(seg_labels).flatten()
    # target labels masked
    target_labels_masked = [x[1] for x in data]
    target_labels_masked = torch.concat(target_labels_masked)
    # target labels 
    target_labels = [x[2] for x in data]
    target_labels = torch.concat(target_labels)
    #seq
    seqs = [x[3] for x in data]
    seqs = tuple(chain.from_iterable(seqs))
    return (masked_sequence, seg_labels),target_labels_masked, target_labels, seqs

In [11]:
def collate_fn_get_embeddings(data): 
    # masked sequence
    masked_sequence = [x[0][0] for x in data]
    #print(masked_sequence[0].shape) 
    masked_sequence = [torch.stack(torch.split(d, split_size_or_sections=50, dim = 0)) for d in masked_sequence]
    #print(masked_sequence[0].shape) 
    masked_sequence = torch.concat(masked_sequence)
    #print(masked_sequence.shape)
    # seg labels
    seg_labels = [x[0][1] for x in data]
    seg_labels = torch.concat(seg_labels).flatten()
    # target labels masked
    target_labels_masked = [x[1] for x in data]
    target_labels_masked = [torch.stack(torch.split(d, split_size_or_sections=50, dim = 0)) for d in target_labels_masked] 
    target_labels_masked = torch.concat(target_labels_masked)
    # target labels 
    target_labels = [x[2] for x in data]
    target_labels = [torch.stack(torch.split(d, split_size_or_sections=50, dim = 0)) for d in target_labels]
    target_labels = torch.concat(target_labels)
    #seq
    seqs = [x[3] for x in data]
    seqs = tuple(chain.from_iterable(seqs))     
    return (masked_sequence, seg_labels), target_labels_masked, target_labels, seqs

Create Dataset and Dataloader for the data: 

In [12]:
test_df = None 

if not input_params.test: #Train and Validate
    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len,
                                                      mask_rate = input_params.mask_rate, split_mask = input_params.split_mask)

    #N_train = int(len(seq_df)*(1-input_params.val_fraction))
    if input_params.fold is not None:
        
        samples = seq_df.sample_id.unique()
        val_samples = samples[input_params.fold::input_params.Nfolds] 
        train_df = seq_df[~seq_df.sample_id.isin(val_samples)] 
        test_df = seq_df[seq_df.sample_id.isin(val_samples)]
        test_dataset = SeqDataset(test_df, transform = seq_transform, mode='eval')
        test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 0, collate_fn = collate_fn, shuffle = False)
    else:
        train_df = seq_df
        #train_df = seq_df[seq_df.split=='train']
        #test_df = seq_df[seq_df.split=='val']
  
    N_train = len(train_df)
    train_fold = np.repeat(list(range(input_params.train_splits)),repeats = N_train // input_params.train_splits + 1 )
    train_df['train_fold'] = train_fold[:N_train]
    # create training dataset & dataloader 
    train_dataset = SeqDataset(train_df, transform = seq_transform,  mode='train')
    train_dataloader = DataLoader(dataset = train_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = collate_fn, shuffle = False)

elif input_params.get_embeddings:
    if input_params.mask_at_test:
        seq_transform = sequence_encoders.RollingMasker(mask_stride = 50, frame = 0)
    else:
        seq_transform = sequence_encoders.PlainOneHot(frame = 0, padding = 'none')
    # create test dataset & dataloader 
    test_dataset = SeqDataset(seq_df, transform = seq_transform, mode='eval')
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = 1, num_workers = 1, collate_fn = collate_fn_get_embeddings, shuffle = False)

else: #Test
    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len,
                                                      mask_rate=input_params.mask_rate, split_mask = input_params.split_mask)
    # create test dataset & dataloader 
    test_dataset = SeqDataset(seq_df, transform = seq_transform, mode='eval')
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = collate_fn, shuffle = False)

In [13]:
(masked_sequence, species_label), targets_masked, targets, _ = next(iter(test_dataloader))

In [14]:
next(iter(test_dataloader))

((tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 1., 1.,  ..., 1., 1., 1.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [1., 0., 1.,  ..., 1., 1., 1.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [1., 1., 0.,  ..., 1., 1., 1.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           ...,
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [1., 1., 1.,  ..., 0., 1., 1.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [1., 1., 1.,  ..., 1., 0., 1.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [1., 1., 1.,  ..., 1., 1., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],
  
  
          [[[0., 0., 1.,  ..., 1., 1., 1.],
            [0., 1., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[1., 0., 1., 

**2. Define model**

In [15]:
seg_encoder = SpecAdd(embed = True, encoder = 'label', Nsegments=seq_df.seg_label.nunique(), d_model = input_params.d_model)

model = DSSResNetEmb(d_input = 3, d_output = 3, d_model = input_params.d_model, n_layers = input_params.n_layers, 
                     dropout = input_params.dropout, embed_before = True, species_encoder = seg_encoder)

model = model.to(device) 

model_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(model_params, lr = input_params.learning_rate, weight_decay = input_params.weight_decay)

**3. Train model**

In [16]:
last_epoch = 0

if input_params.model_weight:

    if torch.cuda.is_available():
        #load on gpu
        model.load_state_dict(torch.load(input_params.model_weight))
        if input_params.optimizer_weight:
            optimizer.load_state_dict(torch.load(input_params.optimizer_weight))
    else:
        #load on cpu
        model.load_state_dict(torch.load(input_params.model_weight, map_location=torch.device('cpu')))
        if input_params.optimizer_weight:
            optimizer.load_state_dict(torch.load(input_params.optimizer_weight, map_location=torch.device('cpu')))

    last_epoch = int(input_params.model_weight.split('_')[-3]) #infer previous epoch from input_params.model_weight

weights_dir = os.path.join(input_params.output_dir, 'weights') #dir to save model weights at save_at epochs

if input_params.save_at:
    os.makedirs(weights_dir, exist_ok = True)

#lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
#        milestones=input_params.lr_sch_milestones, gamma=input_params.lr_sch_gamma, verbose=False) 



In [17]:
def metrics_to_str(metrics):
    loss, accuracy, masked_acc, masked_recall, masked_IQS = metrics
    return f'loss: {loss:.4}, acc: {accuracy:.4}, masked acc: {masked_acc:.4}, {misc.print_class_recall(masked_recall, "masked recall: ")}, masked IQS: {masked_IQS:.4}'

In [18]:
from IPython.display import clear_output

clear_output()

#from helpers.misc import print    #print function that displays time

if not input_params.test:

    for epoch in range(last_epoch+1, input_params.tot_epochs+1):

        print(f'EPOCH {epoch}: Training...')

        #if input_params.masking == 'stratified_maf':

        #    meta = get_random_mask()

        train_dataset.seq_df = train_df[train_df.train_fold == (epoch-1) % input_params.train_splits]
        print(f'using train samples: {list(train_dataset.seq_df.index[[0,-1]])}')

        train_metrics = train_eval.model_train(model, optimizer, train_dataloader, device,
                            silent = False)
        
        print(f'epoch {epoch} - train, {metrics_to_str(train_metrics)}')

        if epoch in input_params.save_at: #save model weights

            misc.save_model_weights(model, optimizer, weights_dir, epoch)

        if test_df is not None  and ( epoch==input_params.tot_epochs or
                            (input_params.validate_every and epoch%input_params.validate_every==0)):

            print(f'EPOCH {epoch}: Validating...')

            val_metrics, *_ =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                    silent = False)

            print(f'epoch {epoch} - validation, {metrics_to_str(val_metrics)}')
            
        #lr_scheduler.step()
else:

    print(f'EPOCH {last_epoch}: Test/Inference...')

    test_metrics, test_embeddings, motif_probas =  train_eval.model_eval(model, optimizer, test_dataloader, device, 
                                                          get_embeddings = input_params.get_embeddings, diploid=True,
                                                          silent = False)
    
    

    print(f'epoch {last_epoch} - test, {metrics_to_str(test_metrics)}')

    if input_params.get_embeddings:
        
        os.makedirs(input_params.output_dir, exist_ok = True)

        with open(input_params.output_dir + '/embeddings.pickle', 'wb') as f:
            #test_embeddings = np.vstack(test_embeddings)
            #np.save(f,test_embeddings)
            pickle.dump(test_embeddings,f)
            #pickle.dump(seq_df.seq_name.tolist(),f)
            
print()
print(f'peak GPU memory allocation: {round(torch.cuda.max_memory_allocated(device)/1024/1024)} Mb')
print('Done')

EPOCH 100: Test/Inference...


  0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               | 0/10060 [00:00<?, ?it/s]

  return einsum('chn,hnl->chl', W, S).float(), state                   # [C H L]


UnboundLocalError: cannot access local variable 'embeddings' where it is not associated with a value

### FOR TESTING: 

In [43]:
(masked_sequence, species_label), targets_masked, targets, seq = next(iter(test_dataloader))

In [44]:
from helpers.metrics import MeanRecall, MaskedAccuracy, IQS
from helpers.misc import EMA, print_class_recall
from torch.nn.functional import log_softmax

temperature = None
criterion = torch.nn.CrossEntropyLoss(reduction = "mean")

accuracy = MaskedAccuracy().to(device)
masked_recall = MeanRecall(Nclasses=4).to(device)
masked_accuracy = MaskedAccuracy().to(device)
masked_IQS = IQS(Nclasses=4).to(device)

model.eval() #model to train mode
avg_loss = 0.
all_embeddings = []
motif_probas = []

with torch.no_grad():
    if True:
        # with batch size = 1, one batch contains the two sequences of one sample 
        # let's extract both sequences and get the predictions of both 
        # afterwards compute the scores for the combined predictions 
        #masked sequence
        masked_sequence1 = torch.split(masked_sequence, split_size_or_sections = 1, dim = 0)[0][0]
        masked_sequence2 = torch.split(masked_sequence, split_size_or_sections = 1, dim = 0)[1][0]
        # targets_masked
        targets_masked1 = torch.split(targets_masked, split_size_or_sections = 1, dim = 0)[0][0]
        targets_masked2 = torch.split(targets_masked, split_size_or_sections = 1, dim = 0)[1][0]
        # targets 
        targets1 = torch.split(targets, split_size_or_sections = 1, dim = 0)[0][0]
        targets2 = torch.split(targets, split_size_or_sections = 1, dim = 0)[1][0]
        species_label = species_label[0]
        species_label = species_label.tile((len(masked_sequence1),))

        masked_sequence1 = masked_sequence1.to(device)
        masked_sequence2 = masked_sequence2.to(device)
        targets_masked1 = targets_masked1.to(device)
        targets_masked2 = targets_masked2.to(device)
        targets1 = targets1.to(device) 
        targets2 = targets2.to(device)   
        species_label = species_label.long().to(device)

        logits1, embeddings1 = model(masked_sequence1, species_label)
        if temperature:
            logits1 /= temperature
        loss1 = criterion(logits1, targets_masked1)
        avg_loss += loss1.item()
        preds1 = torch.argmax(logits1, dim=1)

        logits2, embeddings2 = model(masked_sequence2, species_label)
        if temperature:
            logits2 /= temperature
        loss2 = criterion(logits2, targets_masked1)
        avg_loss += loss2.item()
        preds2 = torch.argmax(logits2, dim=1)

        # "notation":
        # 0 = R
        # 1 = M 
        # 2 = F
        # 3 = B 
        # -100 = masked 
         
        # combine preds 
        combined_preds = preds1+preds2
        combined_preds = torch.where(combined_preds==2, combined_preds +1, 0) # == 3 if both, 0 otherwise 
        temp = combined_preds + preds1 # == 4 if both, 1 if father 
        combined_preds = torch.where(temp==1, temp+1, combined_preds) # == 3 if both, 2 if father, otherwise 0 
        temp = combined_preds + preds2 # == 4 if both, 2 if father, 1 if mother 
        combined_preds = torch.where(temp==1, temp, combined_preds) # == 3 if both, 2 if father, 1 if mother, otherwise 0 

        # combine targets
        combined_targets = targets1+targets2
        combined_targets = torch.where(combined_targets==2, combined_targets +1, 0) # == 3 if both, 0 otherwise 
        temp = combined_targets + targets1 # == 4 if both, 1 if father 
        combined_targets = torch.where(temp==1, temp+1, combined_targets) # == 3 if both, 2 if father, otherwise 0 
        temp = combined_targets + targets2 # == 4 if both, 2 if father, 1 if mother 
        combined_targets = torch.where(temp==1, temp, combined_targets) # == 3 if both, 2 if father, 1 if mother, otherwise 0 

        # combine masked targets 
        combined_targets_masked = targets_masked1+targets_masked2
        combined_targets_masked = torch.where(combined_targets_masked==2, combined_targets_masked +1, 0) # == 3 if both, 0 otherwise 
        temp = combined_targets_masked + targets_masked1 # == 4 if both, 1 if father 
        combined_targets_masked = torch.where(temp==1, temp+1, combined_targets_masked) # == 3 if both, 2 if father, otherwise 0 
        temp = combined_targets_masked + targets_masked2 # == 4 if both, 2 if father, 1 if mother 
        combined_targets_masked = torch.where(temp==1, temp, combined_targets_masked) # == 3 if both, 2 if father, 1 if mother, otherwise 0 
        combined_targets_masked = torch.where(temp==-100, temp, combined_targets_masked)# == 3 if both, 2 if father, 1 if mother, -100 if masked, otherwise 0 

        accuracy.update(combined_preds, combined_targets)
        masked_recall.update(combined_preds, combined_targets_masked)
        masked_accuracy.update(combined_preds, combined_targets_masked)
        masked_IQS.update(combined_preds, combined_targets_masked)


In [45]:
accuracy.compute(), masked_accuracy.compute(), masked_recall.compute(), masked_IQS.compute()

(tensor(0.9735),
 tensor(0.9576),
 array([0.97356606, 0.91608393, 0.92177314, 0.98029196], dtype=float32),
 0.9402140974998474)

In [None]:
# 0 = R
# 1 = M 
# 2 = F
# 3 = B 

In [32]:
# combine predictions
combined_preds = preds1+preds2
combined_preds = torch.where(combined_preds==2, combined_preds +1, 0) # == 3 if both, 0 otherwise 
temp = combined_preds + preds1 # == 4 if both, 1 if father 
combined_preds = torch.where(temp==1, temp+1, combined_preds) # == 3 if both, 2 if father, otherwise 0 
temp = combined_preds + preds2 # == 4 if both, 2 if father, 1 if mother 
combined_preds = torch.where(temp==1, temp, combined_preds) # == 3 if both, 2 if father, 1 if mother, otherwise 0 
combined_preds

tensor([[2, 2, 2,  ..., 2, 0, 0],
        [2, 2, 2,  ..., 2, 0, 0],
        [2, 2, 2,  ..., 2, 0, 0],
        ...,
        [2, 2, 2,  ..., 2, 0, 0],
        [2, 2, 2,  ..., 1, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0]])

In [35]:
# combine targets
combined_targets = targets1+targets2
combined_targets = torch.where(combined_targets==2, combined_targets +1, 0) # == 3 if both, 0 otherwise 
temp = combined_targets + targets1 # == 4 if both, 1 if father 
combined_targets = torch.where(temp==1, temp+1, combined_targets) # == 3 if both, 2 if father, otherwise 0 
temp = combined_targets + targets2 # == 4 if both, 2 if father, 1 if mother 
combined_targets = torch.where(temp==1, temp, combined_targets) # == 3 if both, 2 if father, 1 if mother, otherwise 0 
combined_targets

tensor([[2, 3, 2,  ..., 2, 2, 2],
        [2, 3, 2,  ..., 2, 2, 2],
        [2, 3, 2,  ..., 2, 2, 2],
        ...,
        [2, 3, 2,  ..., 2, 2, 2],
        [2, 3, 2,  ..., 2, 2, 2],
        [2, 3, 2,  ..., 2, 2, 2]])

In [None]:
combined_targets_masked = targets_masked1+targets_masked2
combined_targets_masked= torch.where(combined_targets_masked==2, combined_targets_masked, 0)
# 1 = F
temp = combined_targets_masked + targets_masked1
combined_targets_masked = torch.where(temp==1, temp, combined_targets_masked)
combined_targets_masked = torch.where(temp==-100, temp, combined_targets_masked)

In [39]:
# combine masked targets
combined_targets_masked = targets_masked1+targets_masked2
combined_targets_masked = torch.where(combined_targets_masked==2, combined_targets_masked +1, 0) # == 3 if both, 0 otherwise 
temp = combined_targets_masked + targets_masked1 # == 4 if both, 1 if father 
combined_targets_masked = torch.where(temp==1, temp+1, combined_targets_masked) # == 3 if both, 2 if father, otherwise 0 
temp = combined_targets_masked + targets_masked2 # == 4 if both, 2 if father, 1 if mother 
combined_targets_masked = torch.where(temp==1, temp, combined_targets_masked) # == 3 if both, 2 if father, 1 if mother, otherwise 0 
combined_targets_masked = torch.where(temp==-100, temp, combined_targets_masked)# == 3 if both, 2 if father, 1 if mother, -100 if masked, otherwise 0 
combined_targets_masked

tensor([[   2, -100, -100,  ..., -100, -100, -100],
        [-100,    3, -100,  ..., -100, -100, -100],
        [-100, -100,    2,  ..., -100, -100, -100],
        ...,
        [-100, -100, -100,  ...,    2, -100, -100],
        [-100, -100, -100,  ..., -100,    2, -100],
        [-100, -100, -100,  ..., -100, -100,    2]])

In [37]:
targets_masked1

tensor([[   1, -100, -100,  ..., -100, -100, -100],
        [-100,    1, -100,  ..., -100, -100, -100],
        [-100, -100,    1,  ..., -100, -100, -100],
        ...,
        [-100, -100, -100,  ...,    1, -100, -100],
        [-100, -100, -100,  ..., -100,    1, -100],
        [-100, -100, -100,  ..., -100, -100,    1]])

In [38]:
targets_masked2

tensor([[   0, -100, -100,  ..., -100, -100, -100],
        [-100,    1, -100,  ..., -100, -100, -100],
        [-100, -100,    0,  ..., -100, -100, -100],
        ...,
        [-100, -100, -100,  ...,    0, -100, -100],
        [-100, -100, -100,  ..., -100,    0, -100],
        [-100, -100, -100,  ..., -100, -100,    0]])