# Train Test Split
Version 2, run on Oct 28, 2022.

Choose 20% of genes to be the unseen test set.   
For each test gene, remove all its transcripts from the training set.  

The task is simplified by our upstream data processing.   
In ProcessGenCode, we moved sequences to csv files and filtered:

* Retained only those genes for which there is at least one RCI value in LncAtlas.
* Retained only those non-coding transcripts with gene_type=transcript_type=lncRNA.
* Retained only those coding transcripts with gene_type=transcript_type=protein_coding.

In [1]:
from random import Random
from datetime import datetime
# Timestamp the start of the run.
print(datetime.now())

# Fraction of the gene list that train_test_split assigns to the test set.
TEST_PORTION = 0.2
# Inputs (generated by ProcessGenCode notebook)
# Directory constants end with '/' because full paths are built below
# by plain string concatenation (directory + file name).
GENCODE_DIR = '/Users/jasonmiller/WVU/Localization/GenCode/'
NONCODING_ALL = 'gencode.v42.lncRNA_transcripts.csv'
CODING_ALL = 'gencode.v42.pc_transcripts.csv'
ATLAS_DIR='/Users/jasonmiller/WVU/Localization/LncAtlas/'
ATLAS_DATA='lncATLAS_all_data_RCI.csv'
# Outputs
# All output files are written into this one directory.
SPLIT_DIR = '/Users/jasonmiller/WVU/Localization/TrainTest/'
# Output gene lists
# One gene ID per line, written by save_csv.
CODING_TEST_ID = 'CNRCI_coding_test_genes.gc42.csv'
CODING_TRAIN_ID = 'CNRCI_coding_train_genes.gc42.csv'
NONCODING_TEST_ID = 'CNRCI_noncoding_test_genes.gc42.csv'
NONCODING_TRAIN_ID = 'CNRCI_noncoding_train_genes.gc42.csv'
# Output sequence files
# One transcript per line, written by save_seq.
CODING_TEST_SEQ = 'CNRCI_coding_test_transcripts.gc42.csv'
CODING_TRAIN_SEQ = 'CNRCI_coding_train_transcripts.gc42.csv'
NONCODING_TEST_SEQ = 'CNRCI_noncoding_test_transcripts.gc42.csv'
NONCODING_TRAIN_SEQ = 'CNRCI_noncoding_train_transcripts.gc42.csv'

2022-10-28 16:29:49.952992


In [2]:
def load_sequence_data(filepath):
    '''
    Load transcript sequences. Also,
    load IDs of the genes for which we have sequence.
    The long RNA strings preclude the use of csv.reader utility.
    Expect csv file with this header line:
    transcript_id,gene_id,biotype,length,sequence

    The first line is always treated as the header and discarded.
    Blank lines (for example a trailing empty line at end of file)
    are skipped instead of failing with an IndexError.

    Parameters:
        filepath: path of the csv file to read.
    Returns:
        (gene_list, sequence_data) where gene_list is the sorted list of
        distinct gene IDs and sequence_data has one list of string
        fields per transcript, in file order.
    Raises:
        Exception: if the length field of a row differs from the
        actual length of its sequence field.
    '''
    gene_set = set()
    sequence_data = []
    with open(filepath) as handle:
        # Discard the header line; the default keeps an empty file harmless.
        next(handle, None)
        for line in handle:
            line = line.strip()
            if not line:
                # A blank line holds no record to parse.
                continue
            fields = line.split(',')
            gene_id = fields[1]
            length = int(fields[3])
            sequence = fields[4]
            if length != len(sequence):
                # Show the offending record before aborting.
                print(line)
                raise Exception('Lengths do not match')
            gene_set.add(gene_id)
            sequence_data.append(fields)
    gene_list = sorted(gene_set)
    return gene_list, sequence_data

In [3]:
def inplace_shuffle(rows):
    '''
    Shuffle rows in place and return None.
    A private generator seeded with 42 is created on every call,
    so lists of the same length are always permuted the same way.
    '''
    Random(42).shuffle(rows)

In [4]:
def train_test_split(rows):
    '''
    Cut rows into two slices and return them as (train_set, test_set).
    The test set is the leading int(len(rows)*TEST_PORTION) items;
    the train set is everything after that. Order is preserved.
    '''
    test_count = int(len(rows) * TEST_PORTION)
    return (rows[test_count:], rows[:test_count])

In [5]:
def save_csv(rows,filepath):
    '''
    Write a one-column file: the header line gene_id,
    followed by each item of rows on its own line.
    Any existing file at filepath is overwritten.
    '''
    with open(filepath, 'w') as out:
        out.write('gene_id' + '\n')
        out.writelines(gene_id + '\n' for gene_id in rows)

In [6]:
def assert_exclusivity(list1,list2):
    '''
    Verify that neither list repeats an item and that the two lists
    have no item in common. Returns None when both conditions hold.
    The duplicate check runs first.
    Raises Exception with a message naming the failed condition.
    '''
    unique1, unique2 = set(list1), set(list2)
    has_duplicates = len(list1) != len(unique1) or len(list2) != len(unique2)
    if has_duplicates:
        raise Exception('Lists contained duplicates')
    if not unique1.isdisjoint(unique2):
        raise Exception('Lists are not exclusive')

In [7]:
def get_atlas_genes(filepath):
    '''
    Collect the IDs of genes that have at least one usable CNRCI value.

    The file is read as comma-separated text. The first line is always
    treated as the header and discarded. In each remaining row the
    first field is the gene ID, the third is the data type and the
    fourth is the value. A gene is kept when any of its rows has
    data type 'CNRCI' and a value other than 'NA'.
    Blank lines (for example a trailing empty line at end of file)
    are skipped instead of failing with an IndexError.

    Parameters:
        filepath: path of the atlas csv file to read.
    Returns:
        set of gene ID strings.
    '''
    gene_set = set()
    with open(filepath, 'r') as handle:
        # Discard the header line; the default keeps an empty file harmless.
        next(handle, None)
        for row in handle:
            row = row.strip()
            if not row:
                # A blank line holds no record to parse.
                continue
            fields = row.split(',')
            gene_id = fields[0]
            data_type = fields[2]
            value = fields[3]
            if data_type == 'CNRCI' and value != 'NA':
                gene_set.add(gene_id)
    return gene_set
def get_intersection(all_data, subset_ids):
    '''
    Return a new list of the items of all_data that are also found
    in subset_ids. The order (and any repeats) of all_data is kept.
    '''
    return [gene_id for gene_id in all_data if gene_id in subset_ids]

## Processing

In [8]:
# Load both transcript files; timestamps bracket the load.
print(datetime.now())
# Each call returns (sorted gene ID list, list of per-transcript field lists).
coding_genes,    coding_sequence    = load_sequence_data(GENCODE_DIR+CODING_ALL)
noncoding_genes, noncoding_sequence = load_sequence_data(GENCODE_DIR+NONCODING_ALL)

# Show a small sample of each gene list as a sanity check.
print('First few coding genes:')
print(coding_genes[:5])
print('First few noncoding genes:')
print(noncoding_genes[:5])
print(datetime.now())

2022-10-28 16:29:50.106192
First few coding genes:
['ENSG00000000003', 'ENSG00000000005', 'ENSG00000000419', 'ENSG00000000457', 'ENSG00000000460']
First few noncoding genes:
['ENSG00000082929', 'ENSG00000099869', 'ENSG00000105501', 'ENSG00000115934', 'ENSG00000116652']
2022-10-28 16:29:51.041356


In [9]:
# Genes with at least one non-NA CNRCI value in the atlas file.
atlas_genes = get_atlas_genes(ATLAS_DIR+ATLAS_DATA)
# Keep only the genes that also appear in the atlas set.
# get_intersection walks its first argument in order, so list order is kept.
coding_genes = get_intersection(coding_genes, atlas_genes)
noncoding_genes = get_intersection(noncoding_genes, atlas_genes)

In [10]:
print(datetime.now())
# Shuffle each gene list in place. inplace_shuffle seeds its generator
# with 42 on every call, so the resulting order is reproducible.
inplace_shuffle(coding_genes)
inplace_shuffle(noncoding_genes)

# Show a small sample of each shuffled list.
print('First few coding genes:')
print(coding_genes[:5])
print('First few noncoding genes:')
print(noncoding_genes[:5])
print(datetime.now())

2022-10-28 16:29:51.764794
First few coding genes:
['ENSG00000107679', 'ENSG00000164647', 'ENSG00000169299', 'ENSG00000156787', 'ENSG00000101464']
First few noncoding genes:
['ENSG00000185186', 'ENSG00000259005', 'ENSG00000250775', 'ENSG00000266801', 'ENSG00000236352']
2022-10-28 16:29:51.793808


In [11]:
print(datetime.now())
# train_test_split returns (train, test); the test set is the leading
# TEST_PORTION slice of the shuffled list and the train set is the rest.
coding_train_set,   coding_test_set    = train_test_split(coding_genes)
noncoding_train_set,noncoding_test_set = train_test_split(noncoding_genes)

# Show a small sample of each of the four gene sets.
print('First few coding train genes:')
print(coding_train_set[:5])
print('First few coding test genes:')
print(coding_test_set[:5])
print('First few noncoding train genes:')
print(noncoding_train_set[:5])
print('First few noncoding test genes:')
print(noncoding_test_set[:5])
print(datetime.now())

2022-10-28 16:29:51.810816
First few coding train genes:
['ENSG00000212659', 'ENSG00000103269', 'ENSG00000156206', 'ENSG00000106608', 'ENSG00000150627']
First few coding test genes:
['ENSG00000107679', 'ENSG00000164647', 'ENSG00000169299', 'ENSG00000156787', 'ENSG00000101464']
First few noncoding train genes:
['ENSG00000213904', 'ENSG00000261766', 'ENSG00000273145', 'ENSG00000259977', 'ENSG00000277463']
First few noncoding test genes:
['ENSG00000185186', 'ENSG00000259005', 'ENSG00000250775', 'ENSG00000266801', 'ENSG00000236352']
2022-10-28 16:29:51.812767


In [12]:
print(datetime.now())
# Make sorted copies of the four gene sets; the cells below use only these.
coding_train_sort    = sorted(coding_train_set)
coding_test_sort     = sorted(coding_test_set)
noncoding_train_sort = sorted(noncoding_train_set)
noncoding_test_sort  = sorted(noncoding_test_set)
# Drop the references to the unsorted lists
# (presumably so they are not reused by mistake -- confirm intent).
coding_train_set     = None
coding_test_set      = None
noncoding_train_set  = None
noncoding_test_set   = None

# Show a small sample of each sorted list.
print('First few coding train genes:')
print(coding_train_sort[:5])
print('First few coding test genes:')
print(coding_test_sort[:5])
print('First few noncoding train genes:')
print(noncoding_train_sort[:5])
print('First few noncoding test genes:')
print(noncoding_test_sort[:5])
print(datetime.now())

2022-10-28 16:29:51.843954
First few coding train genes:
['ENSG00000000003', 'ENSG00000000005', 'ENSG00000000419', 'ENSG00000000457', 'ENSG00000000460']
First few coding test genes:
['ENSG00000000938', 'ENSG00000000971', 'ENSG00000001036', 'ENSG00000001460', 'ENSG00000001626']
First few noncoding train genes:
['ENSG00000099869', 'ENSG00000105501', 'ENSG00000116652', 'ENSG00000117242', 'ENSG00000120664']
First few noncoding test genes:
['ENSG00000082929', 'ENSG00000124915', 'ENSG00000130600', 'ENSG00000145063', 'ENSG00000146666']
2022-10-28 16:29:51.878603


In [13]:
print(datetime.now())
print('Save gene ID lists')
# For each of the four sorted gene lists: build the output path,
# print it, then write the list with save_csv (one gene ID per line).
filename = SPLIT_DIR + CODING_TEST_ID
print(filename)
save_csv( coding_test_sort, filename)

filename = SPLIT_DIR + CODING_TRAIN_ID
print(filename)
save_csv( coding_train_sort, filename)

filename = SPLIT_DIR + NONCODING_TEST_ID
print(filename)
save_csv( noncoding_test_sort, filename)

filename = SPLIT_DIR + NONCODING_TRAIN_ID
print(filename)
save_csv( noncoding_train_sort, filename)

2022-10-28 16:29:51.890919
Save gene ID lists
/Users/jasonmiller/WVU/Localization/TrainTest/CNRCI_coding_test_genes.gc42.csv
/Users/jasonmiller/WVU/Localization/TrainTest/CNRCI_coding_train_genes.gc42.csv
/Users/jasonmiller/WVU/Localization/TrainTest/CNRCI_noncoding_test_genes.gc42.csv
/Users/jasonmiller/WVU/Localization/TrainTest/CNRCI_noncoding_train_genes.gc42.csv


In [14]:
print(datetime.now())
# Report gene counts: total after the atlas filter, then train and test.
# Each label names exactly the three numbers printed after it; the first
# label is padded with spaces so the two output lines align.
print('   Coding total, train, test:',
      len(coding_genes), len(coding_train_sort), len(coding_test_sort))
print('Noncoding total, train, test:',
      len(noncoding_genes), len(noncoding_train_sort), len(noncoding_test_sort))


2022-10-28 16:29:51.935692
   Coding total, train, test: 17472 13978 3494
Noncoding total, filtered, train, test: 5827 4662 1165


In [15]:
print(datetime.now())
# Check every pairing of the gene lists. assert_exclusivity raises if
# either list has duplicates or if the two lists share a gene ID, so
# this cell finishing without an exception means all checks passed.
assert_exclusivity(coding_genes,        noncoding_genes)
assert_exclusivity(coding_train_sort,   coding_test_sort)
assert_exclusivity(coding_train_sort,   noncoding_train_sort)
assert_exclusivity(coding_test_sort,    noncoding_test_sort)
assert_exclusivity(coding_test_sort,    noncoding_train_sort)
assert_exclusivity(coding_train_sort,   noncoding_test_sort)
assert_exclusivity(noncoding_train_sort,noncoding_test_sort)

2022-10-28 16:29:51.968653


In [16]:
def save_seq(gene_list,sequence_data,filepath):
    '''
    Write the transcripts that belong to the given genes to a csv file.

    The output starts with the header line
    transcript_id,gene_id,biotype,length,sequence
    followed by one comma-joined line per selected row, in input order.
    Any existing file at filepath is overwritten.

    Parameters:
        gene_list: iterable of gene IDs to keep.
        sequence_data: iterable of rows, each a list of string fields
            whose second field is the gene ID.
        filepath: path of the csv file to write.
    Returns:
        None.
    '''
    # Build the lookup set before opening the file, so an invalid
    # gene_list fails before the output file is created or truncated.
    valid_ids = set(gene_list)
    with open(filepath, 'w') as handle:
        handle.write('transcript_id,gene_id,biotype,length,sequence')
        handle.write('\n')
        for fields in sequence_data:
            # Only the gene ID decides whether the row is written.
            if fields[1] in valid_ids:
                handle.write(','.join(fields))
                handle.write('\n')

In [17]:
print('Save transcript sequences')
# For each of the four sorted gene lists: build the output path,
# print it, then have save_seq write the transcripts of those genes
# taken from the matching (coding or noncoding) sequence data.
filename = SPLIT_DIR + CODING_TEST_SEQ
print(filename)
save_seq( coding_test_sort, coding_sequence, filename)

filename = SPLIT_DIR + CODING_TRAIN_SEQ
print(filename)
save_seq( coding_train_sort, coding_sequence, filename)

filename = SPLIT_DIR + NONCODING_TEST_SEQ
print(filename)
save_seq( noncoding_test_sort, noncoding_sequence, filename)

filename = SPLIT_DIR + NONCODING_TRAIN_SEQ
print(filename)
save_seq( noncoding_train_sort, noncoding_sequence, filename)

Save transcript sequences
/Users/jasonmiller/WVU/Localization/TrainTest/CNRCI_coding_test_transcripts.gc42.csv
/Users/jasonmiller/WVU/Localization/TrainTest/CNRCI_coding_train_transcripts.gc42.csv
/Users/jasonmiller/WVU/Localization/TrainTest/CNRCI_noncoding_test_transcripts.gc42.csv
/Users/jasonmiller/WVU/Localization/TrainTest/CNRCI_noncoding_train_transcripts.gc42.csv


In [18]:
# Mark the end of the run and print the finishing time.
print('done')
print(datetime.now())

done
2022-10-28 16:29:54.295483
