# Train Test Split

Choose 20% of genes to be the unseen test set.   
For each test gene, remove all its transcripts from the training set.  

The task is simplified by our upstream data processing.   
Instead of one csv row per (gene_id,cell_line), we have one row per (gene_id).  

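This notebook reads lines of text, filters them, shuffles them, partitions them, and writes them out.
Apart from reading the gene_id field (to drop genes for which GenCode has no sequence), the code ignores the csv nature of each line.   
However, the code takes the header line from input and repeats it in each output.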

In [1]:
from random import Random
from datetime import datetime
# Timestamp the start of the run.
print(datetime.now())

# Fraction of the rows that train_test_split() assigns to the test set.
TEST_PORTION = 0.2
# Inputs: ATLAS csv files, filtered and then split by this notebook.
# NOTE(review): the directories below are absolute, user-specific paths;
# they must be edited before the notebook can run on another machine.
ATLAS_DIR='/Users/jasonmiller/WVU/Localization/LncAtlas/'
ATLAS_CODING='CNRCI_coding_genes.csv'
ATLAS_NONCODING='CNRCI_noncoding_genes.csv'
# Inputs: GenCode csv files, read only to learn which gene IDs have sequence.
GENCODE_DIR='/Users/jasonmiller/WVU/Localization/GenCode/'
GENCODE_CODING='Homo_sapiens.GRCh38.cds.csv'
GENCODE_NONCODING='Homo_sapiens.GRCh38.ncrna.csv'

# Outputs: file names for the train/test partitions, written into ATLAS_DIR.
CODING_TEST = 'CNRCI_coding_test_genes.csv'
CODING_TRAIN = 'CNRCI_coding_train_genes.csv'
NONCODING_TEST = 'CNRCI_noncoding_test_genes.csv'
NONCODING_TRAIN = 'CNRCI_noncoding_train_genes.csv'

2022-10-11 17:38:57.180916


In [2]:
def load_available_genes(filepath):
    '''Load IDs of the genes for which we have sequence.
    We will ignore ATLAS data if GenCode does not have the sequence.

    The first line of the file is treated as a header and skipped.
    Every other non-blank line contributes its second comma-separated
    field (index 1), stripped of surrounding whitespace, as a gene ID.

    Parameters:
        filepath: path of the csv text file to read.
    Returns:
        set of gene ID strings (empty if the file has no data lines).
    Raises:
        IndexError: if a non-blank data line contains no comma.
    '''
    genes = set()
    with open(filepath) as handle:
        # The long RNA strings preclude the use of csv.reader
        next(handle, None)  # skip the header; an empty file is tolerated
        for line in handle:
            if not line.strip():
                # Blank lines (e.g. a trailing empty line) carry no gene.
                continue
            # Only field 1 is needed, so stop splitting after it rather
            # than cutting the remainder of a long line into pieces.
            fields = line.split(',', 2)
            # Strip so that an ID in the last column loses its newline.
            genes.add(fields[1].strip())
    return genes

In [3]:
# Collect the gene IDs present in the GenCode coding and noncoding csv files.
# The two prints timestamp the start and end of the load.
print(datetime.now())
coding_genes = load_available_genes(GENCODE_DIR+GENCODE_CODING)
noncoding_genes = load_available_genes(GENCODE_DIR+GENCODE_NONCODING)
print(datetime.now())

2022-10-11 17:38:57.208841
2022-10-11 17:38:57.704407


In [4]:
def load_all_rows(filepath):
    '''Read an ATLAS csv (coding or noncoding file) as raw text lines.

    Returns a (header, data) pair: header is the first line of the file
    and data is a list of every remaining line. Lines are not parsed and
    keep their line endings. An empty file raises IndexError.'''
    with open(filepath, 'r') as source:
        lines = list(source)
    return lines[0], lines[1:]

In [5]:
# Read both ATLAS files, keeping each header line apart from its data lines.
# The two prints timestamp the start and end of the load.
print(datetime.now())
header_coding,all_coding_rows = load_all_rows(ATLAS_DIR+ATLAS_CODING)
header_noncoding,all_noncoding_rows = load_all_rows(ATLAS_DIR+ATLAS_NONCODING)
print(datetime.now())

2022-10-11 17:38:57.750389
2022-10-11 17:38:57.785881


In [6]:
def remove_unknown_genes(atlas,genes):
    '''Keep only the atlas rows whose gene ID is a member of genes.

    atlas is an iterable of csv text lines; the gene ID of a line is its
    text before the first comma (the whole line if it has no comma).
    genes is any container supporting membership tests.
    Returns a new list of the retained lines, in their original order.'''
    return [line for line in atlas if line.split(',', 1)[0] in genes]

In [7]:
# Drop ATLAS rows whose gene ID is absent from the matching GenCode gene set.
# The two prints timestamp the start and end of the filtering.
print(datetime.now())
some_coding_rows = remove_unknown_genes(all_coding_rows,coding_genes)
some_noncoding_rows = remove_unknown_genes(all_noncoding_rows,noncoding_genes)
print(datetime.now())

2022-10-11 17:38:57.813903
2022-10-11 17:38:57.853242


In [8]:
def inplace_shuffle(rows):
    '''Shuffle the list rows in place and return None.

    A private generator seeded with 42 is created on every call, so
    the resulting order is the same each time for the same input.'''
    Random(42).shuffle(rows)
# Shuffle each filtered list once before it is partitioned.
# The second call used to repeat some_coding_rows, which left the
# noncoding rows unshuffled, so their split followed input file order.
# Note: the shuffle is in place, so re-running this cell without
# re-running the filtering cell shuffles already-shuffled lists.
inplace_shuffle(some_coding_rows)
inplace_shuffle(some_noncoding_rows)

In [9]:
def train_test_split(rows, test_portion=None):
    '''Partition rows into a training part and a test part.

    The first int(len(rows) * test_portion) rows form the test set and
    the remaining rows form the training set; rows itself is not
    modified. Shuffle rows beforehand if a random split is wanted.

    Parameters:
        rows: sequence supporting len() and slicing.
        test_portion: fraction (0 to 1) of rows to place in the test
            set. Defaults to the module-level TEST_PORTION.
    Returns:
        (train_set, test_set) tuple of slices of rows.
    Raises:
        ValueError: if the fraction is outside the range 0 to 1.
    '''
    if test_portion is None:
        test_portion = TEST_PORTION
    if not 0 <= test_portion <= 1:
        # A negative or >1 fraction would silently mis-slice the rows.
        raise ValueError('test portion must be between 0 and 1')
    divider = int(len(rows) * test_portion)
    train_set = rows[divider:]
    test_set = rows[:divider]
    return (train_set, test_set)
# Partition the coding rows and the noncoding rows independently.
coding_train_set,coding_test_set = train_test_split(some_coding_rows)
noncoding_train_set,noncoding_test_set = train_test_split(some_noncoding_rows)

In [10]:
def save_csv(header,rows,filepath):
    '''Write header followed by every row to filepath as text.

    The file is created or overwritten. Each non-empty line is written
    with exactly the text given, plus a newline if it did not already
    end with one. Without that, a row lacking its line ending (such as
    the unterminated last line of an input file, moved elsewhere by a
    shuffle) would be fused with the row written after it.

    Parameters:
        header: text of the header line.
        rows: iterable of text lines.
        filepath: path of the output file.
    '''
    def _terminated(text):
        # Return text ending in a newline; empty text is left unchanged.
        if text and not text.endswith('\n'):
            return text + '\n'
        return text

    with open(filepath, 'w') as handle:
        handle.write(_terminated(header))
        handle.writelines(_terminated(line) for line in rows)

In [11]:
# Write the four partitions into ATLAS_DIR, each preceded by the header
# line taken from the ATLAS file it came from.
filename = ATLAS_DIR + CODING_TEST
save_csv( header_coding, coding_test_set, filename)

filename = ATLAS_DIR + CODING_TRAIN
save_csv( header_coding, coding_train_set, filename)

filename = ATLAS_DIR + NONCODING_TEST
save_csv( header_noncoding, noncoding_test_set, filename)

filename = ATLAS_DIR + NONCODING_TRAIN
save_csv( header_noncoding, noncoding_train_set, filename)

In [12]:
# Marker printed once the preceding cells have run.
print('done')

done


In [13]:
# Report row counts at each stage: as loaded, after filtering, and per partition.
print(
    '   Coding total, filtered, train, test:',
    *map(len, (all_coding_rows, some_coding_rows, coding_train_set, coding_test_set)),
)
print(
    'Noncoding total, filtered, train, test:',
    *map(len, (all_noncoding_rows, some_noncoding_rows, noncoding_train_set, noncoding_test_set)),
)


   Coding total, filtered, train, test: 17770 17412 13930 3482
Noncoding total, filtered, train, test: 6768 5777 4622 1155
