<a href="https://colab.research.google.com/github/musophobia/KrittimBuddhi/blob/master/MutSpace.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#utils.py
import json
from os.path import join
import os
from random import sample
import torch
import h5py
from tqdm import tqdm
import numpy as np
#model.py
import torch
from torch import nn
from torch.optim import Adam
#train.py
import sys
import torch
import argparse
# from utils import *
#from model import MutSpace
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import numpy as np

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# utils.py
def get_seq_mapping():
    """Build the lookup tables that encode mutations and nucleotides as integers.

    Returns:
        tuple: ``(mut_mapping, nuk_mapping)``. ``mut_mapping`` maps the six
        substitution labels (``'C->A'`` ... ``'T->G'``) to 0-5 and
        ``nuk_mapping`` maps ``'A'``, ``'C'``, ``'G'``, ``'T'`` to 0-3.
    """
    # TODO add support for indel and make it more convenient to modify this function
    mutation_labels = ('C->A', 'C->G', 'C->T', 'T->A', 'T->C', 'T->G')
    nucleotide_labels = ('A', 'C', 'G', 'T')
    mut_mapping = {label: code for code, label in enumerate(mutation_labels)}
    nuk_mapping = {label: code for code, label in enumerate(nucleotide_labels)}
    return mut_mapping, nuk_mapping


def get_rev_seq_mapping():
    """Return the inverse tables of ``get_seq_mapping`` (integer -> label).

    Returns:
        tuple: ``(mut_mapping, nuk_mapping)`` mapping 0-5 to the substitution
        labels and 0-3 to the nucleotide letters.
    """
    mut_mapping = dict(enumerate(('C->A', 'C->G', 'C->T', 'T->A', 'T->C', 'T->G')))
    nuk_mapping = dict(enumerate(('A', 'C', 'G', 'T')))
    return mut_mapping, nuk_mapping


def index(seq):
    """Encode a mutation pattern as a single integer, reading ``seq`` as base-4 digits.

    ``seq`` is laid out as ``[up..., down..., mut]``: the nucleotide codes of
    the upstream bases, then those of the downstream bases, and finally the
    code of the mutation itself. Element ``i`` is weighted by ``4 ** i``, so
    the first element is the least significant digit.

    Example: ``[3, 1, 0, 2, 4] -> 1159``.
    """
    return sum((4 ** position) * digit for position, digit in enumerate(seq))


def rev_index(num, n_bits):
    """Decode an integer produced by ``index`` back into its ``n_bits`` digits.

    Example: ``rev_index(1159, 5) -> [3, 1, 0, 2, 4]``.

    The digits are peeled off from the most significant position downwards,
    so the last element of the result absorbs whatever exceeds base 4 (the
    mutation code may be 4 or 5).
    """
    digits = []
    for power in range(n_bits - 1, -1, -1):
        digit, num = divmod(num, 4 ** power)
        digits.append(digit)
    # Collected most-significant first; callers expect least-significant first.
    digits.reverse()
    return digits


class Dataset:
    """Minimal in-memory dataset holding a list of cases in ``self.data``.

    Each case is expected to look like ``[id, a_features, b_features]``;
    ``sample`` relies on the id being the first element.
    """

    def __init__(self, config, data_path):
        self.config, self.data_path = config, data_path
        self.data = []

    def process(self):
        """
        Hook for turning the raw input into a list of
        [[id, a_features, b_features]] cases; this base version does nothing.
        :return: None
        """
        return None

    def sample(self, n_items, drop_id=None):
        """Draw ``n_items`` cases without replacement, discarding those whose id equals ``drop_id``.

        When ``drop_id`` is given the result may therefore be shorter than
        ``n_items``.
        """
        drawn = sample(self.data, n_items)
        if drop_id is None:
            return drawn
        return [case for case in drawn if case[0] != drop_id]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]


class MutationDatasetClassic:
    """Patient-by-pattern mutation count matrix built from the TSV files in a folder.

    ``self.data`` is a 2-D numpy array with one row per patient/sample and one
    column per mutation pattern (``feature_num`` columns).
    """

    def __init__(self, data_path, context_width):
        """
        :param data_path: folder containing ``meta.json`` and one or more ``*.tsv`` files
        :param context_width: number of flanking bases kept on each side of the mutation
        """
        self.data_path = data_path
        self.mut_mapping, self.nuk_mapping = get_seq_mapping()
        self.context_width = context_width
        self.feature_dict = dict()
        # 6 substitution types times 4 choices for each of the 2 * width flanking bases
        self.feature_num = (6 * 4 ** (2 * self.context_width))
        self.data = self.process()

    def process(self):
        """
        Read every TSV file and count, per patient, how often each mutation
        pattern occurs.

        Side effects: writes ``patient2row.json`` (patient id -> matrix row) and
        ``pattern2col.json`` (pattern string -> matrix column) into ``data_path``.

        :return: count matrix of shape (n_patients, feature_num)
        """
        patient2row = dict()
        pattern2col = dict()
        datalist = os.listdir(self.data_path)
        datalist = [".".join(name.split('.')[:-1]) for name in datalist if name.endswith('tsv')]
        # meta.json maps logical column names ('uid', 'upstream', ...) to column indices
        with open(join(self.data_path, 'meta.json')) as meta_file:
            header = json.load(meta_file)
        print(datalist)
        matrix_list = []
        for cname in datalist:
            print(f'Loading {cname}...')
            with open(os.path.join(self.data_path, f'{cname}.tsv'), 'r') as fin:
                for i, line in enumerate(fin):
                    if i == 0:  # skip header
                        continue
                    line = line.strip().split('\t')
                    # uid is patient id or sample id
                    uname = line[header['uid']]
                    if uname not in patient2row:
                        patient2row[uname] = len(patient2row)
                        matrix_list.append(np.zeros(self.feature_num))
                    uid = patient2row[uname]
                    # keep only the bases adjacent to the mutation
                    up = line[header['upstream']][-self.context_width:]
                    down = line[header['downstream']][:self.context_width]
                    var_type = line[header['var_type']]
                    pattern = up + '(' + var_type + ')' + down
                    idx = self.decompose(up, down, var_type)
                    pattern2col.setdefault(pattern, idx)
                    # update count
                    matrix_list[uid][idx] += 1
        if matrix_list:
            matrix = np.stack(matrix_list)
        else:
            # np.stack rejects an empty list; return an empty matrix instead
            matrix = np.zeros((0, self.feature_num))
        with open(join(self.data_path, 'patient2row.json'), 'w') as fout:
            json.dump(patient2row, fout)
        with open(join(self.data_path, 'pattern2col.json'), 'w') as fout:
            json.dump(pattern2col, fout)
        return matrix

    def decompose(self, upstream, downstream, mutation):
        """Convert the flanking bases and the mutation type into a column index."""
        mutation = [self.mut_mapping[mutation]]
        up = [self.nuk_mapping[s] for s in upstream]
        down = [self.nuk_mapping[s] for s in downstream]
        return index(up + down + mutation)

    def __len__(self):
        # Number of rows (patients); the original referenced an undefined name here.
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]


class MutationDataset:
    """Per-mutation dataset of ``[uid, a_features, b_features]`` cases read from TSV files.

    Feature ids in ``[0, 6 * 4 ** (2 * ring_width) * ring_num)`` are reserved
    for the sequence-context "rings"; every other categorical value (patient
    ids and the columns listed under ``'a'`` / ``'b'`` in ``meta.json``) gets
    the next free id from ``get_feature_id``.
    """

    def __init__(self, config, data_path):
        """
        :param config: settings object; reads ``ring_width``, ``ring_num``,
            ``debug`` and ``ckpt_path``
        :param data_path: folder containing ``meta.json`` and the ``*.tsv`` files
        """
        self.config = config
        self.data_path = data_path
        self.mut_mapping, self.nuk_mapping = get_seq_mapping()
        self.ring_width = self.config.ring_width
        self.ring_num = self.config.ring_num
        self.feature_dict = dict()
        self.feature_num = (6 * 4 ** (2 * self.config.ring_width)) * self.config.ring_num
        self.data = self.process()
        # feature counts are taken from the last case
        self.num_fa = len(self.data[-1][1])
        self.num_fb = len(self.data[-1][2])

    def get_feature_id(self, f):
        """
        Return the feature index of ``f``, assigning the next free index the
        first time ``f`` is seen.
        ``self.feature_num`` starts at the number of ring features, so new ids
        never collide with them.
        note f is a string so that only categorical features are supported
        """
        if f not in self.feature_dict:
            self.feature_dict[f] = self.feature_num
            self.feature_num += 1
        return self.feature_dict[f]

    def process(self, fresh = True):
        """
        Parse every TSV file in ``self.data_path`` into a list of cases.

        The file name (without extension) is appended to each row as an extra
        last column, so ``meta.json`` may reference it by index.

        :param fresh: when true, dump ``feature_dict.json`` and
            ``patient_mapping.json`` into ``config.ckpt_path``
        :return: list of ``[uid, a_features, b_features]``
        """
        data = []
        patient_mapping = dict()
        with open(join(self.data_path, 'meta.json')) as meta_file:
            header = json.load(meta_file)
        # List the same folder the files are opened from (the original listed
        # config.data_path, which breaks when it differs from data_path).
        datalist = os.listdir(self.data_path)
        datalist = [".".join(name.split('.')[:-1]) for name in datalist if name.endswith('tsv')]
        print(datalist)
        for cname in datalist:
            print(f'Loading {cname}...')
            with open(os.path.join(self.data_path, f'{cname}.tsv'), 'r') as fin:
                for i, line in enumerate(fin):
                    if i == 0:  # skip header
                        continue
                    line = line.strip().split('\t')
                    line.append(cname)

                    # uid is patient id or sample id
                    pname = line[header['uid']]
                    uid = self.get_feature_id(pname)
                    # a_features: categorical feature for each patient/sample
                    # b_features: categorical feature for each project/data
                    # convert mutation and its context into index
                    a_features = list(self.decompose(line[header['upstream']],
                                                     line[header['downstream']],
                                                     line[header['var_type']]))
                    a_features.extend(self.get_feature_id(line[col]) for col in header['a'])
                    b_features = [self.get_feature_id(line[col]) for col in header['b']]

                    case = [uid, a_features, b_features]
                    data.append(case)

                    # remember the first case seen for each patient
                    if pname not in patient_mapping:
                        patient_mapping[pname] = case

            if self.config.debug:
                # debug mode: load only the first file
                break
        if fresh:
            with open(join(self.config.ckpt_path, 'feature_dict.json'), 'w') as fout:
                json.dump(self.feature_dict, fout)
            with open(join(self.config.ckpt_path, 'patient_mapping.json'), 'w') as fout:
                json.dump(patient_mapping, fout)
        return data

    def decompose(self, upstream, downstream, mutation):
        """
        Split the sequence context into ``ring_num`` rings of ``ring_width``
        bases on each side and return one feature index per ring.
        Ring ``i`` is offset by ``i * base`` so rings do not share indices.
        """
        rings = []
        w = self.ring_width
        mutation = [self.mut_mapping[mutation]]
        base = 6 * 4 ** (2 * self.ring_width)
        for i in range(self.ring_num):
            if i == 0:
                # upstream[-w:-0] would be empty, so the innermost ring is special-cased
                up = upstream[-w:]
            else:
                up = upstream[-(i + 1) * w:-i * w]
            down = downstream[i * w:(i + 1) * w]
            up = [self.nuk_mapping[s] for s in up]
            down = [self.nuk_mapping[s] for s in down]
            rings.append(index(up + down + mutation) + i * base)
        return rings

    def sample(self, n_items, drop_id=None):
        """Draw ``n_items`` cases without replacement, dropping those whose uid equals ``drop_id``."""
        if drop_id is not None:
            samples = [case for case in sample(self.data, n_items) if case[0] != drop_id]
        else:
            samples = sample(self.data, n_items)
        return samples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

    def save_hdf5(self, path):
        """Write the dataset to an HDF5 file with one group per case, keyed by its position."""
        with h5py.File(path, 'w') as f:
            f['feature_num'] = self.feature_num
            f['num_fa'] = self.num_fa
            f['num_fb'] = self.num_fb
            f['total'] = len(self.data)
            for i, item in enumerate(tqdm(self.data)):
                group = f.create_group(str(i))
                group['uid'] = item[0]
                group['a_features'] = item[1]
                group['b_features'] = item[2]


class MutationDatasetH5PY:
    """Dataset backed by an HDF5 file (see ``MutationDataset.save_hdf5`` for the layout).

    Cases are read from the file on demand; negatives are drawn from the
    JSON patient mapping (patient name -> ``[uid, a_features, b_features]``).
    """

    def __init__(self, config, data_path, patient_mapping_path):
        """
        :param config: settings object, stored as-is
        :param data_path: path of the HDF5 file; it stays open for the lifetime of the object
        :param patient_mapping_path: path of the JSON patient mapping
        """
        self.data = h5py.File(data_path, 'r')
        self.config = config
        self.feature_num = self.data['feature_num'][()]
        self.num_fa = self.data['num_fa'][()]
        self.num_fb = self.data['num_fb'][()]
        self.total = self.data['total'][()]
        with open(patient_mapping_path) as mapping_file:
            self.patient_mapping = json.load(mapping_file)

    def __getitem__(self, item):
        group = self.data[str(item)]
        uid = group['uid'][()]
        fa = list(group['a_features'])
        fb = list(group['b_features'])
        return [uid, fa, fb]

    def __len__(self):
        # the stored total is a numpy scalar; len() should hand back a plain int
        return int(self.total)

    def sample(self, n_negative, drop_id=None):
        """
        Draw up to ``n_negative`` distinct patients and return their cases,
        excluding the one whose uid equals ``drop_id``.

        The result can be shorter than ``n_negative`` when ``drop_id`` is hit
        or when fewer patients are available.
        """
        # random.sample needs a sequence; dict views are rejected on Python 3.11+
        keys = list(self.patient_mapping)
        picked = sample(keys, min(n_negative, len(keys)))
        samples = [self.patient_mapping[k] for k in picked if self.patient_mapping[k][0] != drop_id]
        return samples


class MyCollator:
    """Collate function that pairs every positive case with sampled negatives.

    For each case in the batch it asks the dataset for ``n_negative`` negative
    cases (excluding the case's own uid), pads the result to a fixed length
    and returns everything as tensors on ``device``.
    """

    def __init__(self, config, dataset, device=None):
        """
        :param config: settings object; reads ``n_negative``
        :param dataset: object exposing ``num_fa``, ``num_fb`` and
            ``sample(n, drop_id=...)``
        :param device: target device for the tensors; defaults to CUDA when
            available, otherwise CPU
        """
        self.dataset = dataset
        self.num_fa = dataset.num_fa
        self.num_fb = dataset.num_fb
        self.n_negative = config.n_negative
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    def __call__(self, batch):
        """
        :param batch: list of ``[uid, a_features, b_features]`` cases
        :return: dict with ``pos_a`` [N x n_a], ``pos_b`` [N x n_b],
            ``neg_b`` [N x n_negative x n_b] and ``neg_mask`` [N x n_negative]
            (1 for a real negative, 0 for padding)
        """
        pos_a_features = []
        pos_b_features = []
        neg_b_features = []
        neg_mask = []
        for pos_case in batch:
            neg_sample = self.dataset.sample(self.n_negative, drop_id=pos_case[0])
            pos_a_features.append(pos_case[1])
            pos_b_features.append(pos_case[2])
            row_neg = [neg_case[2] for neg_case in neg_sample]
            # the sampler may return fewer cases than requested; pad to a fixed width
            pad_num = self.n_negative - len(row_neg)
            neg_mask.append([1] * len(row_neg) + [0] * pad_num)
            # padded slots use feature id 0 and are switched off through the mask
            row_neg.extend([0] * self.num_fb for _ in range(pad_num))
            neg_b_features.append(row_neg)

        return {'pos_a': torch.LongTensor(pos_a_features).to(self.device),
                'pos_b': torch.LongTensor(pos_b_features).to(self.device),
                'neg_b': torch.LongTensor(neg_b_features).to(self.device),
                'neg_mask': torch.FloatTensor(neg_mask).to(self.device)}

In [None]:
#model.py

class InnerProductSimilarity(nn.Module):
    """Inner product between embeddings, divided by ``dim ** temp``."""

    def __init__(self, temp):
        """
        :param temp: exponent applied to the embedding dimension in the divisor
        """
        super(InnerProductSimilarity, self).__init__()
        self.temp = temp

    def forward(self, a, b):
        """
        :param a: [N x dim]
        :param b: [N x dim] (one candidate per row) or [N x K x dim] (K candidates per row)
        :return: [N] or [N x K]
        :raises ValueError: if ``b`` is neither 2-D nor 3-D
        """
        d = a.shape[1]
        a = a.unsqueeze(1)  # N x 1 x dim
        if b.dim() == 2:
            b = b.unsqueeze(2)  # N x dim x 1
            # Squeeze only the two singleton matrix dims; a bare squeeze()
            # would also drop the batch dim when N == 1.
            similarity = torch.bmm(a, b).squeeze(2).squeeze(1)
        elif b.dim() == 3:
            # N x neg x dim
            similarity = torch.sum(a * b, dim=-1)
        else:
            raise ValueError(f'b must be 2-D or 3-D, got {b.dim()} dimensions')
        return similarity / pow(d, self.temp)   # [N] or [N x n_neg]


class MarginRankingLoss(nn.Module):
    """Hinge loss that asks each positive score to beat its negatives by ``margin``."""

    def __init__(self, margin=1., aggregate=torch.mean):
        super().__init__()
        self.margin, self.aggregate = margin, aggregate

    def forward(self, positive_similarity, negative_similarity, negative_mask):
        """
        :param positive_similarity: [N]
        :param negative_similarity: [N x K]
        :param negative_mask: [N x K]
        :return: ``aggregate`` applied to the element-wise hinge terms
        """
        pos_column = positive_similarity.unsqueeze(1)  # [N x 1], broadcasts over K
        violation = self.margin - pos_column + negative_similarity
        # masked-out entries become 0 before clamping, so they add no loss
        hinge = torch.clamp(violation * negative_mask, min=0)
        return self.aggregate(hinge)


class MutSpace(nn.Module):
    """Embedding model trained with a margin ranking loss over positive/negative pairs.

    A single embedding table is shared by all features; each side of a pair
    is represented by the sum of its feature embeddings divided by the square
    root of the number of features.
    """

    def __init__(self, config, n_features):
        """
        :param config: settings object; reads ``emb_dim``, ``max_norm``,
            ``temp``, ``margin`` and ``lr``
        :param n_features: number of rows in the embedding table
        """
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(n_features, config.emb_dim, max_norm=config.max_norm)
        self.similarity = InnerProductSimilarity(config.temp)
        self.loss = MarginRankingLoss(margin=config.margin)
        self.optimizer = Adam(self.parameters(), lr=config.lr)

    def forward(self, batch):
        """
        :param batch: dict with ``pos_a``, ``pos_b``, ``neg_b`` and ``neg_mask``
        :return: (loss, positive scores, negative scores)
        """
        mask = batch['neg_mask']                     # [N x n_neg]
        emb_a = self.embedding(batch['pos_a'])       # [N x n_a x d]
        emb_b = self.embedding(batch['pos_b'])       # [N x n_b x d]
        emb_neg = self.embedding(batch['neg_b'])     # [N x n_neg x n_b x d]
        # pool the feature axis: sum, then scale by sqrt(number of features)
        vec_a = emb_a.sum(dim=1) / pow(emb_a.shape[1], 0.5)
        vec_b = emb_b.sum(dim=1) / pow(emb_b.shape[1], 0.5)
        vec_neg = emb_neg.sum(dim=2) / pow(emb_neg.shape[2], 0.5)
        pos_score = self.similarity(vec_a, vec_b)
        neg_score = self.similarity(vec_a, vec_neg)
        loss = self.loss(pos_score, neg_score, mask)
        return loss, pos_score, neg_score

    def train_batch(self, batch):
        """Run one optimisation step and return (loss, mean positive score, mean negative score) as floats."""
        self.train()
        loss, pos_score, neg_score = self.forward(batch)
        loss.backward()
        self.optimizer.step()
        # clear gradients so the next step starts from zero
        self.optimizer.zero_grad()
        mean_pos = pos_score.mean().item()
        mean_neg = neg_score.mean().item()
        return loss.item(), mean_pos, mean_neg

In [None]:
#train.py

def save_settings():
    '''
    Seed the random number generators and save parameters into a json file.

    Reads the module-level ``config``: ``config.seed`` seeds Python's
    ``random`` (used for negative sampling), torch and numpy, and the whole
    ``config.__dict__`` is written to ``<config.ckpt_path>/setting.json``.
    The checkpoint folder is created if it does not exist.
    '''
    # Local import: the file only imports `sample` from random, not the module.
    import random

    # Negative sampling goes through random.sample, so this generator must be
    # seeded as well for runs to be reproducible.
    random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    np.random.seed(config.seed)
    os.makedirs(config.ckpt_path, exist_ok=True)
    setting_fn = os.path.join(config.ckpt_path, 'setting.json')
    with open(setting_fn, 'w') as fout:
        json.dump(config.__dict__, fout)

def parse_arg():
    '''
    Build the option parser and return the default settings.

    The parser is fed an empty argument list, so the result always holds the
    defaults declared below regardless of the process command line.
    '''
    option_specs = [
        ('--name', dict(type=str, help="Name or ID of this run. A folder with this name will be created in the directory where this script is being executed")),
        ('--data_path', dict(type=str, help="Directory of mutation data")),
        ('--ring_num', dict(type=int, default=6, help="Number of sub-components for the sequence context")),
        ('--ring_width', dict(type=int, default=1, help="The width of sequence sub-component")),
        ('--margin', dict(type=float, default=1.0, help="The constant margin parameter in the hinge-loss function, default is 1.0")),
        ('--max_norm', dict(type=float, default=10.0, help="The Max norm of embeddings, default is 10.0")),
        ('--emb_dim', dict(type=int, default=200, help="Dimension of embedded vector, default is 200")),
        ('--epochs', dict(type=int, default=50, help="Number of epochs before training is stopped")),
        ('--batch_size', dict(type=int, default=4096, help="Batch size, default is 4096")),
        ('--lr', dict(type=float, default=1e-3, help="Learning rate, default is 1e-3")),
        ('--n_negative', dict(type=int, default=15, help="Number of negative samples per positive sample generated through negative sampling, default is 15")),
        # '--verbose_step' is not an option here; the caller sets it on the result
        ('--debug', dict(action='store_true', help="")),
        ('--temp', dict(type=float, default=1.0, help="Normalization parameter for calculation of similarity")),
        ('--seed', dict(type=int, default=888, help="Random seed, default is 888")),
    ]
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args(args=[])

if __name__ == "__main__":
    # set parameters
    # NOTE: `config` must stay a module-level name because save_settings() reads it.
    config = parse_arg()
    config.data_path = '/content/drive/My Drive/MutSpace/demo_data/ICGC-BRCA-EU/'
    config.name = 'ICGC-BRCA-EU'
    config.temp = 1.0
    config.seed = 88
    config.ring_num = 5
    config.ring_width = 1
    # set check point output folder
    config.ckpt_path = f'/content/drive/My Drive/MutSpace/demo_data/ICGC-BRCA-EU/ckpt/{config.name}'
    # save parameters (also seeds the RNGs and creates the checkpoint folder)
    save_settings()
    # load mutation data
    dataset = MutationDataset(config, config.data_path)
    print(f'Total mutation = {len(dataset)}')
    mycollator = MyCollator(config, dataset)
    dataloader = DataLoader(dataset,
                            batch_size=config.batch_size,
                            shuffle=True,
                            num_workers=0,
                            collate_fn=mycollator)
    # training process
    model = MutSpace(config=config, n_features=dataset.feature_num)
    model = model.cuda()
    # log twice per epoch; never 0, which would break the modulo below
    config.verbose_step = max(1, len(dataloader) // 2)
    range_loss = range_ps = range_ns = 0
    # Number of steps accumulated since the last log line. Steps left over at
    # the end of an epoch roll into the next window, so the averages divide by
    # the real count instead of by verbose_step.
    range_steps = 0
    with open(os.path.join(config.ckpt_path, 'record.txt'), 'a') as record_fout:
        for epc in range(config.epochs):
            for step, batch in enumerate(tqdm(dataloader)):
                loss, ps, ns = model.train_batch(batch)
                range_loss += loss
                range_ps += ps
                range_ns += ns
                range_steps += 1
                if (step + 1) % config.verbose_step == 0:
                    log = f'loss = {range_loss / range_steps} pos score = {range_ps / range_steps}, neg score = {range_ns / range_steps}\n'
                    print(log)
                    record_fout.write(log)
                    range_loss = range_ps = range_ns = 0
                    range_steps = 0

            # one checkpoint per epoch, next to setting.json and record.txt
            torch.save(model.state_dict(), os.path.join(config.ckpt_path, f'{epc}.pth'))
    # finish

<_io.TextIOWrapper name='/content/drive/My Drive/MutSpace/demo_data/ICGC-BRCA-EU/meta.json' mode='r' encoding='UTF-8'>
['ICGC-BRCA-EU']
Loading ICGC-BRCA-EU...
Total mutation = 3263296


 50%|████▉     | 398/797 [01:29<01:27,  4.58it/s]

loss = 0.9881439674739263 pos score = 0.006210270961043109, neg score = -0.00013228240182893867



100%|██████████| 797/797 [02:59<00:00,  4.45it/s]

loss = 0.9467039660892295 pos score = 0.0481834588946784, neg score = -0.000459961990072119




 50%|████▉     | 398/797 [01:29<01:24,  4.71it/s]

loss = 0.8794479450987811 pos score = 0.11599738384835684, neg score = -0.0041019044820589



100%|██████████| 797/797 [02:56<00:00,  4.52it/s]

loss = 0.8357611091891725 pos score = 0.16238801873958292, neg score = -0.0037542748797438695




 50%|████▉     | 398/797 [01:29<01:28,  4.53it/s]

loss = 0.8236102268623946 pos score = 0.1836144214299456, neg score = -0.0019444895643675236



100%|██████████| 797/797 [02:58<00:00,  4.47it/s]

loss = 0.8158629081357065 pos score = 0.19163842192247285, neg score = -0.0011782654631050074




 50%|████▉     | 398/797 [01:28<01:25,  4.67it/s]

loss = 0.8136562538805918 pos score = 0.19783356149292472, neg score = -0.0010225758165904552



100%|██████████| 797/797 [02:56<00:00,  4.53it/s]

loss = 0.8085519830186163 pos score = 0.20144028898010302, neg score = -0.0008529384082002841




 50%|████▉     | 398/797 [01:29<01:27,  4.57it/s]

loss = 0.808043666371149 pos score = 0.20548102559156753, neg score = -0.000734069426203887



100%|██████████| 797/797 [02:59<00:00,  4.44it/s]

loss = 0.8055650955768087 pos score = 0.20603597044345723, neg score = -0.0005540207148670794




 50%|████▉     | 398/797 [01:27<01:26,  4.62it/s]

loss = 0.8059293575322808 pos score = 0.2089616497557367, neg score = -0.0002849304542979009



100%|██████████| 797/797 [02:54<00:00,  4.57it/s]

loss = 0.8030204435988287 pos score = 0.20952454497616493, neg score = -0.00041631293078897206




 50%|████▉     | 398/797 [01:29<01:32,  4.31it/s]

loss = 0.8043813467325278 pos score = 0.21127173787535136, neg score = -0.00010443543143178911



100%|█████████▉| 796/797 [03:00<00:00,  4.35it/s]

loss = 0.8018079955673697 pos score = 0.2117063794378659, neg score = 0.00015227208724159186



100%|██████████| 797/797 [03:00<00:00,  4.42it/s]
 50%|████▉     | 398/797 [01:32<01:29,  4.48it/s]

loss = 0.8028777395960075 pos score = 0.21365729045478543, neg score = 0.0004745114011776404



100%|██████████| 797/797 [03:02<00:00,  4.36it/s]

loss = 0.8015383171975313 pos score = 0.2125897425502988, neg score = 0.0005644300729421335




 50%|████▉     | 398/797 [01:29<01:29,  4.46it/s]

loss = 0.8027455197207293 pos score = 0.2140644175847571, neg score = 0.0006055943597254266



100%|██████████| 797/797 [03:00<00:00,  4.43it/s]

loss = 0.8003804957746861 pos score = 0.2143590395249913, neg score = 0.0010668325123720462




 50%|████▉     | 398/797 [01:31<01:30,  4.39it/s]

loss = 0.8022009236129684 pos score = 0.21520494774508117, neg score = 0.0011535191141770133



100%|██████████| 797/797 [03:02<00:00,  4.37it/s]

loss = 0.8002929554213232 pos score = 0.21497388434919282, neg score = 0.0015412106182469158




 50%|████▉     | 398/797 [01:30<01:28,  4.50it/s]

loss = 0.8019509898058733 pos score = 0.2160345739006397, neg score = 0.0016900239245867521



100%|█████████▉| 796/797 [02:58<00:00,  4.39it/s]

loss = 0.7997606330780528 pos score = 0.21561455340990468, neg score = 0.0015350154329055168



100%|██████████| 797/797 [02:59<00:00,  4.45it/s]
 50%|████▉     | 398/797 [01:30<01:27,  4.56it/s]

loss = 0.801459159833103 pos score = 0.2166059962918411, neg score = 0.0016373528082142606



100%|█████████▉| 796/797 [02:59<00:00,  4.60it/s]

loss = 0.7997232955604342 pos score = 0.21602018458310085, neg score = 0.0018473258227203217



100%|██████████| 797/797 [02:59<00:00,  4.43it/s]
 50%|████▉     | 398/797 [01:27<01:28,  4.52it/s]

loss = 0.801801706229023 pos score = 0.2169882079314946, neg score = 0.0023241101398571344



100%|██████████| 797/797 [02:56<00:00,  4.52it/s]

loss = 0.7990849795353473 pos score = 0.2169408110443072, neg score = 0.0020245956565922503




 50%|████▉     | 398/797 [01:30<01:30,  4.42it/s]

loss = 0.8016031026540689 pos score = 0.21728889865042575, neg score = 0.0022421636315368746



100%|█████████▉| 796/797 [03:02<00:00,  4.34it/s]

loss = 0.7991185432402932 pos score = 0.21737993265216674, neg score = 0.0023267690493819133



100%|██████████| 797/797 [03:02<00:00,  4.36it/s]
 50%|████▉     | 398/797 [01:28<01:27,  4.57it/s]

loss = 0.8010173997088293 pos score = 0.2180633953393404, neg score = 0.0022979738954421207



100%|██████████| 797/797 [02:56<00:00,  4.51it/s]

loss = 0.7992008072347497 pos score = 0.21738182850668777, neg score = 0.002357164470397641




 50%|████▉     | 398/797 [01:28<01:26,  4.60it/s]

loss = 0.8013134950668968 pos score = 0.2178856379647351, neg score = 0.0023338391027246143



100%|█████████▉| 796/797 [02:58<00:00,  4.43it/s]

loss = 0.7987029354775971 pos score = 0.21823594618083245, neg score = 0.002663034666306999



100%|██████████| 797/797 [02:58<00:00,  4.47it/s]
 50%|████▉     | 398/797 [01:26<01:26,  4.63it/s]

loss = 0.8006298808596242 pos score = 0.2187803026344908, neg score = 0.0024984231142084422



100%|██████████| 797/797 [02:52<00:00,  4.61it/s]

loss = 0.7989495967201252 pos score = 0.21808351471495988, neg score = 0.0026269032792565006




 50%|████▉     | 398/797 [01:27<01:25,  4.67it/s]

loss = 0.8012887941832518 pos score = 0.21836962329981915, neg score = 0.0027008700988919564



100%|██████████| 797/797 [02:54<00:00,  4.56it/s]

loss = 0.7984128025308925 pos score = 0.21869917842910516, neg score = 0.002657471293869505




 50%|████▉     | 398/797 [01:29<01:27,  4.57it/s]

loss = 0.8007738647149436 pos score = 0.21907363599868276, neg score = 0.002910209370059616



100%|██████████| 797/797 [02:57<00:00,  4.50it/s]

loss = 0.7986699529928178 pos score = 0.21870783873688635, neg score = 0.002868477717131566




 50%|████▉     | 398/797 [01:29<01:29,  4.47it/s]

loss = 0.800710377531435 pos score = 0.2192627630715993, neg score = 0.002901472286944124



100%|██████████| 797/797 [02:56<00:00,  4.50it/s]

loss = 0.7987354048831978 pos score = 0.21870273624982067, neg score = 0.0029451023931780593




 50%|████▉     | 398/797 [01:30<01:24,  4.72it/s]

loss = 0.8004965535060844 pos score = 0.21972951060862997, neg score = 0.0031313115673583525



100%|██████████| 797/797 [02:56<00:00,  4.50it/s]

loss = 0.7988527875449789 pos score = 0.218776831525055, neg score = 0.0030842948661750106




 50%|████▉     | 398/797 [01:28<01:29,  4.47it/s]

loss = 0.8007755370894868 pos score = 0.2193662143831876, neg score = 0.0030911863657719957



100%|██████████| 797/797 [02:57<00:00,  4.50it/s]

loss = 0.7983852990308599 pos score = 0.2192399200407704, neg score = 0.0030990417631782197




 50%|████▉     | 398/797 [01:29<01:30,  4.42it/s]

loss = 0.8005979052141085 pos score = 0.21997234689530415, neg score = 0.0033911521698899005



100%|██████████| 797/797 [02:58<00:00,  4.67it/s]

loss = 0.7984265211838574 pos score = 0.2192290166215082, neg score = 0.0030815832014466182



100%|██████████| 797/797 [02:58<00:00,  4.46it/s]
 50%|████▉     | 398/797 [01:29<01:27,  4.55it/s]

loss = 0.8001926198377082 pos score = 0.22030982981674635, neg score = 0.003233461086468515



100%|██████████| 797/797 [02:56<00:00,  4.52it/s]

loss = 0.7987762935197533 pos score = 0.21910006197253665, neg score = 0.003250683099483225




 50%|████▉     | 398/797 [01:29<01:26,  4.62it/s]

loss = 0.79998866112987 pos score = 0.2205376087421149, neg score = 0.0033231002445090916



100%|██████████| 797/797 [02:57<00:00,  4.50it/s]

loss = 0.798754798557291 pos score = 0.21922116752844958, neg score = 0.003374818206263704




 50%|████▉     | 398/797 [01:25<01:26,  4.60it/s]

loss = 0.8003812222624543 pos score = 0.2202545007194706, neg score = 0.003482888350137771



100%|██████████| 797/797 [02:53<00:00,  4.58it/s]

loss = 0.7982194817545426 pos score = 0.21976765013070562, neg score = 0.003422915490133603




 50%|████▉     | 398/797 [01:28<01:29,  4.45it/s]

loss = 0.8001084245329526 pos score = 0.22047404042590205, neg score = 0.003414933151873642



100%|██████████| 797/797 [02:58<00:00,  4.47it/s]

loss = 0.7982374420717134 pos score = 0.2197920438557414, neg score = 0.0034219032654733044




 50%|████▉     | 398/797 [01:28<04:20,  1.53it/s]

loss = 0.800070718005674 pos score = 0.2208057404238375, neg score = 0.0036493567420565975



100%|██████████| 797/797 [02:57<00:00,  4.50it/s]

loss = 0.7984141533698269 pos score = 0.21988893783272212, neg score = 0.003685554858213491




 50%|████▉     | 398/797 [01:30<01:26,  4.59it/s]

loss = 0.8001997866223206 pos score = 0.22078476393192856, neg score = 0.0037942534668195436



100%|██████████| 797/797 [02:59<00:00,  4.43it/s]

loss = 0.7981135440831209 pos score = 0.2202586199919782, neg score = 0.003751760649852921




 50%|████▉     | 398/797 [01:27<01:29,  4.46it/s]

loss = 0.7999962117504235 pos score = 0.22094241499751058, neg score = 0.0037324383964613245



100%|██████████| 797/797 [02:55<00:00,  4.54it/s]

loss = 0.7983230491678919 pos score = 0.21999048640081031, neg score = 0.003711591641396223




 50%|████▉     | 398/797 [01:29<01:29,  4.44it/s]

loss = 0.800183484901735 pos score = 0.22081860821301014, neg score = 0.003792060367179339



100%|██████████| 797/797 [02:59<00:00,  4.44it/s]

loss = 0.7979270281204626 pos score = 0.22048804399805452, neg score = 0.0038175714093412623




 50%|████▉     | 398/797 [01:28<01:29,  4.46it/s]

loss = 0.7998599907261642 pos score = 0.2213234283292114, neg score = 0.00396000529380715



100%|██████████| 797/797 [02:55<00:00,  4.62it/s]

loss = 0.798205723415068 pos score = 0.22021130120484672, neg score = 0.0038241125500620334



100%|██████████| 797/797 [02:55<00:00,  4.53it/s]
 50%|████▉     | 398/797 [01:26<01:26,  4.63it/s]

loss = 0.800089677524327 pos score = 0.22096085177743854, neg score = 0.0038993597893729553



100%|██████████| 797/797 [02:54<00:00,  4.56it/s]

loss = 0.797891985381668 pos score = 0.22073300122915201, neg score = 0.004009734965366084




 50%|████▉     | 398/797 [01:28<01:29,  4.48it/s]

loss = 0.7996224626224844 pos score = 0.22163268556846447, neg score = 0.004070628619983973



100%|██████████| 797/797 [02:57<00:00,  4.66it/s]

loss = 0.7982568023492344 pos score = 0.22032501260239876, neg score = 0.003925813416572055



100%|██████████| 797/797 [02:57<00:00,  4.50it/s]
 50%|████▉     | 398/797 [01:28<01:26,  4.60it/s]

loss = 0.7997484536626231 pos score = 0.22168309945407225, neg score = 0.004273238324979379



100%|██████████| 797/797 [02:56<00:00,  4.52it/s]

loss = 0.798255494192018 pos score = 0.22062003238117275, neg score = 0.004292073144107856




 50%|████▉     | 398/797 [01:28<01:29,  4.48it/s]

loss = 0.7999715899402772 pos score = 0.22158228635937724, neg score = 0.004343314582102314



100%|██████████| 797/797 [02:58<00:00,  4.46it/s]

loss = 0.7979931256279873 pos score = 0.22091993891713607, neg score = 0.0042848233759323374




 50%|████▉     | 398/797 [01:29<01:28,  4.51it/s]

loss = 0.7996497091336466 pos score = 0.22192309662025778, neg score = 0.004348543589126783



100%|██████████| 797/797 [02:57<00:00,  4.49it/s]

loss = 0.7982835715739571 pos score = 0.22092220378131722, neg score = 0.004614960515571502




 50%|████▉     | 398/797 [01:29<01:30,  4.41it/s]

loss = 0.7998754501941815 pos score = 0.2218373569112327, neg score = 0.004616545603819567



100%|██████████| 797/797 [02:58<00:00,  4.47it/s]

loss = 0.7982140361663684 pos score = 0.22090083259583718, neg score = 0.0045259868587737995




 50%|████▉     | 398/797 [01:28<01:30,  4.41it/s]

loss = 0.8000681097783036 pos score = 0.22178268290344794, neg score = 0.004760858831100427



100%|██████████| 797/797 [02:56<00:00,  4.51it/s]

loss = 0.7977927316672838 pos score = 0.2214361540785986, neg score = 0.004669278073255847




 50%|████▉     | 398/797 [01:32<01:33,  4.29it/s]

loss = 0.7994256850762583 pos score = 0.22235701206940503, neg score = 0.00461319839522177



100%|██████████| 797/797 [03:02<00:00,  4.38it/s]

loss = 0.7982562763906603 pos score = 0.22104544031560122, neg score = 0.0047059473688009276




 50%|████▉     | 398/797 [01:29<01:29,  4.47it/s]

loss = 0.8000930370996945 pos score = 0.2217082115423739, neg score = 0.00464082188616099



100%|██████████| 797/797 [02:57<00:00,  4.48it/s]

loss = 0.7975552249793432 pos score = 0.22190510829788956, neg score = 0.004906925350220245




 50%|████▉     | 398/797 [01:29<01:31,  4.38it/s]

loss = 0.7997351374158908 pos score = 0.2220741538471313, neg score = 0.004694142456662747



100%|██████████| 797/797 [02:59<00:00,  4.43it/s]

loss = 0.7979710224884838 pos score = 0.2214535169265977, neg score = 0.004885009114556152




 50%|████▉     | 398/797 [01:28<01:29,  4.48it/s]

loss = 0.7997257718488798 pos score = 0.22223680455181466, neg score = 0.004848483295739896



100%|██████████| 797/797 [02:57<00:00,  4.50it/s]

loss = 0.7979699496048779 pos score = 0.22162832372152624, neg score = 0.005077835267111916




 50%|████▉     | 398/797 [01:31<01:30,  4.43it/s]

loss = 0.7998185395894937 pos score = 0.22238690319971824, neg score = 0.0050663381718924995



100%|██████████| 797/797 [03:02<00:00,  4.38it/s]

loss = 0.7978901548601275 pos score = 0.22168168534136298, neg score = 0.005031526732417473




 50%|████▉     | 398/797 [01:30<01:27,  4.57it/s]

loss = 0.7999973123397061 pos score = 0.22236037426557972, neg score = 0.005327247041354917



100%|██████████| 797/797 [03:01<00:00,  4.39it/s]

loss = 0.7977847340118945 pos score = 0.22198324294845065, neg score = 0.005208075841268574




 50%|████▉     | 398/797 [01:29<01:32,  4.33it/s]

loss = 0.7994854814143636 pos score = 0.22275635701178306, neg score = 0.00509870747871663



100%|██████████| 797/797 [03:00<00:00,  4.41it/s]

loss = 0.7978832859489786 pos score = 0.22174860836572982, neg score = 0.005074528255058406




 50%|████▉     | 398/797 [01:30<01:29,  4.46it/s]

loss = 0.7999612100459822 pos score = 0.22273078388604686, neg score = 0.005688663991779651



100%|█████████▉| 796/797 [02:59<00:00,  4.55it/s]

loss = 0.7976985570773407 pos score = 0.22210500217113063, neg score = 0.00529089727225351



100%|██████████| 797/797 [02:59<00:00,  4.45it/s]
 50%|████▉     | 398/797 [01:28<01:27,  4.57it/s]

loss = 0.799666970069684 pos score = 0.222695174603606, neg score = 0.005331413344249461



100%|██████████| 797/797 [02:55<00:00,  4.55it/s]

loss = 0.7976304929160593 pos score = 0.22222842208704158, neg score = 0.005383261164724763




 50%|████▉     | 398/797 [01:28<01:26,  4.64it/s]

loss = 0.7996884807569897 pos score = 0.22294337272494283, neg score = 0.0055253249662548035



100%|██████████| 797/797 [02:57<00:00,  4.48it/s]

loss = 0.7978774753946755 pos score = 0.2220696328797532, neg score = 0.005455389216045139




 50%|████▉     | 398/797 [01:27<01:28,  4.51it/s]

loss = 0.8002134387816616 pos score = 0.22266818141218406, neg score = 0.005895932459222865



100%|██████████| 797/797 [02:56<00:00,  4.51it/s]

loss = 0.7974525884767274 pos score = 0.22261438556202692, neg score = 0.005572349517609124






In [None]:
# NOTE: opening meta.json with mode 'w' would truncate it; inspect it in read mode instead.
# with open('/content/drive/My Drive/MutSpace/demo_data/ICGC-BRCA-EU/meta.json', 'r') as f:
#   print(len(f.read()))