In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import os
import numpy as np
import pandas as pd
from cmapPy.pandasGEXpress.parse import parse

In [None]:
def load_geneid_list(input_file):
    """Read a list of gene IDs from a text file holding one ID per line.

    Args:
        input_file: Path (as a string) to a ``.txt`` file.

    Returns:
        A list of strings, one per line of the file, with the line
        terminator removed. Empty lines are kept as empty strings.

    Raises:
        ValueError: If ``input_file`` is not a string, does not end with
            ``.txt``, or the file cannot be read.
    """
    if not isinstance(input_file, str):
        raise ValueError("input should be a string for a path of gene ID list file")
    if not input_file.endswith(".txt"):
        raise ValueError("input file should be a text file")
    try:
        with open(input_file, 'r') as f:
            # Strip only the newline: slicing off the last character would
            # corrupt the final ID when the file has no trailing newline.
            return [line.rstrip("\n") for line in f]
    except OSError as err:
        # Keep ValueError as the single exception type raised by this loader,
        # but preserve the underlying cause for debugging.
        raise ValueError(f"could not read gene ID list file: {input_file}") from err
    
def l1000_to_image(input, lmids):
    """Convert every column of an L1000 table into a (27, 36) image.

    Args:
        input: pandas DataFrame of L1000 data; rows are looked up by the
            labels in ``lmids`` and each column becomes one image.
        lmids: list of landmark gene ids; its length must be 970, 486, 324
            or 108.

    Returns:
        A list with one (27, 36) numpy array per column of ``input``.

    Raises:
        ValueError: If ``len(lmids)`` is not one of the supported sizes
            (raised while processing the first column).

    Note:
        Values are copied unchanged, so scaling to (0, 1) has to be done
        before calling this function.
    """
    n_landmarks = len(lmids)
    landmark_rows = input.loc[lmids, :]
    images = []
    for col in range(input.shape[1]):
        values = landmark_rows.iloc[:, col].tolist()

        if n_landmarks == 108:
            # 9x12 grid, each gene blown up to a 3x3 patch
            coarse = np.reshape(values, (9, 12))
            images.append(coarse.repeat(3, axis=0).repeat(3, axis=1))
            continue

        if n_landmarks == 970:
            # two trailing zeros pad 970 values up to 27 * 36 = 972
            flat = values + [0, 0]
        elif n_landmarks in (486, 324):
            # each gene is repeated 2 (486) or 3 (324) times to reach 972
            flat = np.repeat(values, 972 // n_landmarks)
        else:
            raise ValueError(f"Invalid length of lmids: {n_landmarks}")

        images.append(np.reshape(flat, (27, 36)))
    return images

def _rnaseq_landmark_mask(num_lm):
    """Return a (108, 144) boolean array, True on pixels reserved for landmark genes.

    Raises ValueError when ``num_lm`` is not 970, 486, 324 or 108.
    """
    row = np.arange(108).reshape(-1, 1)
    col = np.arange(144).reshape(1, -1)
    if num_lm == 970:
        # central 2x2 pixels of every 4x4 tile
        return ((row % 4 >= 1) & (row % 4 <= 2)) & ((col % 4 >= 1) & (col % 4 <= 2))
    if num_lm == 486:
        # 4 + 3 = 7 pixels inside every 4x8 tile
        upper = (row % 4 == 1) & (col % 8 >= 2) & (col % 8 <= 5)
        lower = (row % 4 == 2) & (col % 8 >= 3) & (col % 8 <= 5)
        return upper | lower
    if num_lm == 324:
        # 2x5 pixels inside every 4x12 tile
        return ((row % 4 >= 1) & (row % 4 <= 2)) & ((col % 12 >= 3) & (col % 12 <= 7))
    if num_lm == 108:
        # 6x5 pixels inside every 12x12 tile
        return ((row % 12 >= 3) & (row % 12 <= 8)) & ((col % 12 >= 3) & (col % 12 <= 7))
    raise ValueError(f"Invalid length of lmids: {num_lm}")


def _rnaseq_landmark_canvas(lm_values, num_lm):
    """Tile one sample's landmark values (a list) over the full (108, 144) canvas."""
    if num_lm == 108:
        return np.repeat(np.repeat(np.reshape(lm_values, (9, 12)), 12, axis=0), 12, axis=1)
    if num_lm == 970:
        # two trailing zeros pad 970 values up to 27 * 36 = 972
        flat = lm_values + [0, 0]
    else:
        # 486 -> each value twice, 324 -> each value three times
        flat = np.repeat(lm_values, 972 // num_lm)
    return np.repeat(np.repeat(np.reshape(flat, (27, 36)), 4, axis=0), 4, axis=1)


def rnaseq_to_image(input, lmids):
    """Convert every column of an RNA-seq table into a (108, 144) image.

    Landmark genes are tiled over the canvas and occupy a fixed group of
    pixels inside each tile; every remaining pixel is filled, in row-major
    order, with the non-landmark genes (rows of ``input`` not listed in
    ``lmids``), followed by zeros once those run out.

    Args:
        input: pandas DataFrame (or Series, treated as a single column) of
            RNA-seq data indexed by gene id.
        lmids: list of landmark gene ids; its length must be 970, 486, 324
            or 108.

    Returns:
        A list with one (108, 144) numpy array per column of ``input``.

    Raises:
        ValueError: If ``len(lmids)`` is unsupported, or if there are more
            non-landmark rows than free pixels.

    Note:
        Values are copied unchanged, so scaling to (0, 1) has to be done
        before calling this function. String-valued input is supported and
        yields a string-typed image.
    """
    num_lm = len(lmids)
    # Building the mask also validates num_lm before any heavy work is done.
    free_pixels = ~_rnaseq_landmark_mask(num_lm)
    capacity = int(free_pixels.sum())

    if isinstance(input, pd.Series):
        input = pd.DataFrame(input)

    # Row selection does not depend on the column, so do it once instead of
    # re-slicing (and copying) the whole frame for every sample.
    lm_rows = input.loc[lmids, :]
    nonlm_rows = input.drop(index=lmids)
    num_nonlm = nonlm_rows.shape[0]
    if num_nonlm > capacity:
        raise ValueError(
            f"Too many non-landmark genes ({num_nonlm}) for the {capacity} free pixels"
        )
    padding = [0.0] * (capacity - num_nonlm)

    lst = []
    for i in range(input.shape[1]):
        out_image = _rnaseq_landmark_canvas(lm_rows.iloc[:, i].tolist(), num_lm)
        filler = np.array(nonlm_rows.iloc[:, i].tolist() + padding)
        if filler.dtype.kind in "US":
            # String input (e.g. gene ids): widen the canvas dtype so that the
            # non-landmark strings are not truncated on assignment.
            out_image = out_image.astype(filler.dtype)
        # Boolean-mask assignment fills the selected pixels in row-major order.
        out_image[free_pixels] = filler
        lst.append(out_image)
    return lst

def image_to_rnaseq(image, lmids, all_gene_ids):
    """Recover per-gene values from an RNA-seq image.

    The pixel-to-gene layout is reconstructed by running ``rnaseq_to_image``
    on the gene ids themselves; each gene's value is then the mean of all
    pixels assigned to it.

    Args:
        image: numpy ndarray (or pandas DataFrame, converted with
            ``.values``) with the same shape as the images produced by
            ``rnaseq_to_image``.
        lmids: list of landmark gene ids, or path to a ``.txt`` file
            holding them.
        all_gene_ids: list of all gene ids, or path to a ``.txt`` file
            holding them. Its order defines the order of the output rows.

    Returns:
        numpy ndarray of shape ``(len(all_gene_ids), 1)`` with the inferred
        value of every gene.

    Raises:
        TypeError: If ``image``, ``lmids`` or ``all_gene_ids`` has an
            unsupported type.
        ValueError: If ``image`` does not have the expected shape.
        KeyError: If a gene in ``all_gene_ids`` is not assigned to any pixel.
    """
    # sanity checks
    if isinstance(image, pd.DataFrame):
        image = image.values
    elif not isinstance(image, np.ndarray):
        raise TypeError(f"Invalid image type: {type(image)}; image should be pandas DataFrame or numpy array")

    if not isinstance(lmids, list):
        if isinstance(lmids, str) and lmids.endswith(".txt"):
            lmids = load_geneid_list(lmids)
        else:
            raise TypeError(f"Invalid lmids type: {type(lmids)}; lmids should be list of landmark ids or string of path to landmark ids text file")

    if not isinstance(all_gene_ids, list):
        if isinstance(all_gene_ids, str) and all_gene_ids.endswith(".txt"):
            all_gene_ids = load_geneid_list(all_gene_ids)
        else:
            raise TypeError(f"Invalid all_gene_ids type: {type(all_gene_ids)}; all_gene_ids should be list of gene ids or string of path to gene ids text file")

    # produce an array indicating which gene id is allocated to each pixel
    id_df = pd.DataFrame(all_gene_ids, index=all_gene_ids)
    id_arr = rnaseq_to_image(id_df, lmids)[0]

    if id_arr.shape != image.shape:
        raise ValueError(f"Invalid input image shape: {image.shape}")

    # Collect the pixel values of every gene. A set makes the membership test
    # O(1) per pixel; padding pixels (ids not in all_gene_ids) are skipped.
    known_ids = set(all_gene_ids)
    pixels_by_gene = {}
    for gene_id, value in zip(id_arr.ravel().tolist(), image.ravel()):
        if gene_id in known_ids:
            pixels_by_gene.setdefault(gene_id, []).append(value)

    missing = [gene_id for gene_id in all_gene_ids if gene_id not in pixels_by_gene]
    if missing:
        raise KeyError(f"No pixel is assigned to gene id(s): {missing[:5]}")

    # if a gene is inferred by multiple pixels, use the average of the values
    return np.array([[np.mean(pixels_by_gene[gene_id])] for gene_id in all_gene_ids])

def _save_split_images(images, root, train, valid, test):
    """Write ``images[i]`` to ``root/(train|valid|test)/<i as 4 digits>.csv``, split by position."""
    for i, image in enumerate(images):
        # By position, images are divided into three distinct directories
        if i < train:
            split = "train"
        elif i < train + valid:
            split = "valid"
        elif i < train + valid + test:
            split = "test"
        else:
            # Samples beyond the requested split sizes are not written.
            break
        split_dir = os.path.join(root, split)
        os.makedirs(split_dir, exist_ok=True)
        np.savetxt(os.path.join(split_dir, f"{i:04d}.csv"), image, delimiter=",")


def produce_training_images(L1000_gctx, RNAseq_gctx, outpath, lmids,
                            train=2500, valid=500, test=176):
    """Produce image files (.csv) from a pair of L1000 / RNA-seq gctx files.

    Images are written as
    ``outpath/(L1000 or RNAseq)/(train, valid or test)/NNNN.csv`` where
    ``NNNN`` is the zero-padded column position of the sample.

    Args:
        L1000_gctx: path of the L1000 gctx file (read with ``parse``).
        RNAseq_gctx: path of the RNA-seq gctx file (read with ``parse``).
        outpath: root output directory; created when missing.
        lmids: list of landmark gene ids, or path to a ``.txt`` file
            holding them.
        train: number of leading samples written to ``train``.
        valid: number of following samples written to ``valid``.
        test: number of following samples written to ``test``.

    Samples at positions ``>= train + valid + test`` are skipped.
    """
    # load required data
    if isinstance(lmids, str) and lmids.endswith('.txt'):
        lmids = load_geneid_list(lmids)
    l1000 = parse(L1000_gctx).data_df.loc[lmids, :]
    rnaseq = parse(RNAseq_gctx).data_df
    os.makedirs(outpath, exist_ok=True)

    # scaling to (0,1)
    # NOTE(review): np.max on a DataFrame delegates to DataFrame.max(axis=None);
    # whether that is a per-column or a global maximum appears to depend on the
    # pandas version -- confirm which scaling is intended and make it explicit.
    l1000_scaled = l1000 / np.max(l1000)
    l1000_image_list = l1000_to_image(l1000_scaled, lmids)
    rnaseq_scaled = rnaseq / np.max(rnaseq)
    rnaseq_image_list = rnaseq_to_image(rnaseq_scaled, lmids)

    # Save L1000 and RNAseq images with the same split rule
    _save_split_images(l1000_image_list, os.path.join(outpath, "L1000"), train, valid, test)
    _save_split_images(rnaseq_image_list, os.path.join(outpath, "RNAseq"), train, valid, test)