In [13]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch

In [14]:
def pyLoadData(ext, path):
    """
    Read an expression table from disk into a DataFrame.

    Parameters:
    ext (str): File extension selecting the parser: 'csv', 'tsv' or 'xlsx'.
        Matching ignores case, surrounding whitespace and a leading dot.
    path: Location of the file; passed unchanged to the pandas reader.

    Returns:
    pd.DataFrame: The table as parsed by pandas.

    Raises:
    ValueError: If the extension is not supported, or if pandas fails to
        read the file (the underlying exception is chained as the cause).
    """
    # Normalise so that "CSV" or ".csv" select the same reader as "csv".
    kind = str(ext).strip().lower().lstrip('.')

    # Reject unsupported types before the try block so this message is not
    # re-wrapped by the generic read-failure handler below.
    if kind not in ('csv', 'tsv', 'xlsx'):
        raise ValueError("Invalid file; Please upload a .csv, .tsv, or .xlsx file")

    try:
        if kind == 'csv':
            data = pd.read_csv(path)
        elif kind == 'tsv':
            data = pd.read_csv(path, sep='\t')
        else:
            data = pd.read_excel(path, engine='openpyxl')
    except Exception as e:
        # Callers only need to handle ValueError; keep the root cause attached.
        raise ValueError(f"Failed to read the file: {e}") from e

    return data

# Load the example input table (path is relative to the notebook's working
# directory) and show it as the cell output.
data = pyLoadData("csv", "../data/GBMPurity-test-input.csv")
data

Unnamed: 0,GeneName,JAN_P,AAB_P,AAC_P,AAF_P,AAG_P,AAJ_P,AAL_P,AAM_P,AAN_P,...,GLSS-SM-R080_R,GLSS-SM-R080_P,GLSS-SM-R083_R,GLSS-SM-R083_P,GLSS-SM-R088_R,GLSS-SM-R088_P,GLSS-SM-R093_R,GLSS-SM-R093_P,TCGA-06-0125_R,TCGA-06-0125_P
0,AAAS,543.0,704.0,471.0,517.0,655.0,431.0,654.0,567.0,681.0,...,2147.0,2187.0,232.0,500.0,2252.0,2696.0,1587.0,2020.0,1084.0,2602.0
1,AADAT,160.0,203.0,117.0,220.0,259.0,266.0,146.0,145.0,107.0,...,1929.0,1312.0,808.0,842.0,325.0,270.0,511.0,782.0,275.0,918.0
2,AAGAB,518.0,785.0,431.0,289.0,639.0,391.0,615.0,460.0,522.0,...,2321.0,1472.0,1362.0,2076.0,1059.0,1085.0,1655.0,2248.0,1569.0,2147.0
3,AAK1,4397.0,4871.0,3418.0,3521.0,5450.0,4107.0,4469.0,3094.0,4695.0,...,1517.0,621.0,252.0,214.0,158.0,273.0,651.0,644.0,812.0,427.0
4,AAR2,458.0,299.0,421.0,206.0,305.0,333.0,328.0,300.0,309.0,...,2120.0,1912.0,195.0,109.0,1137.0,705.0,1571.0,1673.0,1535.0,3084.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5824,ZW10,378.0,270.0,334.0,262.0,318.0,279.0,264.0,235.0,264.0,...,776.0,738.0,817.0,604.0,542.0,664.0,539.0,606.0,744.0,1027.0
5825,ZWILCH,75.0,185.0,102.0,62.0,133.0,101.0,151.0,116.0,44.0,...,1344.0,1131.0,2764.0,2753.0,377.0,673.0,563.0,741.0,328.0,1197.0
5826,ZWINT,237.0,268.0,324.0,177.0,167.0,100.0,200.0,194.0,97.0,...,2900.0,1296.0,2004.0,971.0,1386.0,1799.0,1147.0,1136.0,180.0,1092.0
5827,ZXDC,2443.0,1981.0,2225.0,1100.0,925.0,1171.0,1685.0,1076.0,1585.0,...,943.0,1237.0,120.0,212.0,607.0,630.0,543.0,1164.0,1053.0,2158.0


In [15]:
def tpm(X: np.ndarray, lengths: np.ndarray):
    """
    Length-normalise raw read counts and rescale each sample to a fixed total.

    Parameters:
    X (np.ndarray): 2D array of raw read counts (samples x genes).
    lengths (np.ndarray): 1D array of feature lengths, one per column of X.

    Returns:
    np.ndarray: Array with the same shape as X in which every row sums to 1e4
        (NaN entries are ignored when computing the row totals).

    Raises:
    ValueError: If X does not have one column per entry of lengths.

    NOTE(review): despite the TPM name, rows are scaled to 1e4 rather than 1e6.
    GBMPurity feeds log2(tpm + 1) to a saved model, which was presumably fitted
    on this scale - confirm before changing the constant.
    """

    if X.shape[1] != lengths.shape[0]:
        raise ValueError("The number of columns in X must match the length of lengths")

    # Counts per unit of feature length. Lengths are used as given; any constant
    # unit factor (e.g. bases vs kilobases) cancels in the normalisation below.
    rpk = np.divide(X, lengths)

    # Per-sample total, kept as a column vector so it broadcasts across genes.
    scaling_factor = np.nansum(rpk, axis=1).reshape(-1, 1)

    # Rescale so that each sample sums to 1e4.
    normalised = (rpk / scaling_factor) * 1e4

    return normalised

In [16]:
def pyCheckData(df):
    """
    Validate an uploaded expression table and reshape it for GBMPurity.

    Parameters:
    df (pd.DataFrame): Genes in rows. The first column holds the gene names
        and every remaining column is one sample.

    Returns:
    tuple: (errors, warnings, data).
        errors (list of str): Blocking problems. Checking stops at the first
            one, so this holds at most one message.
        warnings (list of str): Non-blocking notes (missing values, duplicate
            genes, partial gene coverage).
        data (pd.DataFrame or None): None when an error was found. Otherwise a
            samples x genes frame whose columns follow the 'feature_name'
            order of ../data/GBMPurity_genes.csv; required genes absent from
            the input are filled with 0 and duplicate genes are summed.
    """

    errors = []
    warnings = []

    # Import required genes
    gene_lengths = pd.read_csv("../data/GBMPurity_genes.csv")
    genes = gene_lengths['feature_name']

    # Check Dimensions
    if df.shape[0] < 1:
        errors.append("We didn't detect any genes.")
        return errors, warnings, None

    if df.shape[1] <= 1:
        errors.append("We didn't detect any samples.")
        return errors, warnings, None

    # Missing values
    n_missing = df.isnull().values.sum()
    if n_missing > 0:
        warnings.append(f"We found {n_missing} missing values. These will be converted to 0.")
        df = df.fillna(0)

    # Check for duplicate genes
    input_genes = df.iloc[:, 0]
    duplicate_genes = input_genes[input_genes.duplicated()].unique()
    if len(duplicate_genes) > 0:
        warnings.append(f'We found {len(duplicate_genes)} duplicate genes. Counts for these genes will be summed for each sample.')

    data = df.set_index(df.columns[0])

    # Check appropriate genes
    overlap = set(input_genes).intersection(set(genes))
    if len(overlap) == 0:
        errors.append("We didn't find any required genes. Are the provided genes in the HGNC format e.g. CD47?")
        return errors, warnings, None
    else:
        p_overlap = len(overlap)/len(genes)
        if p_overlap < 0.8:
            errors.append(f"We only found {p_overlap * 100}% of the required genes. Purity estimates will be unreliable under 80%.")
            return errors, warnings, None
        elif p_overlap < 0.99:
            warnings.append(f"We found {int(p_overlap * 100)}% of the required genes. Note that GBMPurity tends to underestimate the tumour purity with more missing genes.")

    # Check correct data
    # Non-numeric: validate the raw cells BEFORE summing duplicates. Summing
    # first can concatenate strings, drop the column or raise, depending on the
    # pandas version. Missing values were filled above, so any NaN produced by
    # the coercion marks a cell that is not a number.
    numeric = data.apply(lambda s: pd.to_numeric(s, errors='coerce'))
    if numeric.isnull().values.any():
        errors.append("All gene expression values must be numeric.")
        return errors, warnings, None

    # Negative values: also checked on the raw cells, so a negative count cannot
    # be masked by a positive duplicate of the same gene.
    if (numeric.values < 0).any():
        errors.append("Gene expression values must be non-negative. Data should be uploaded as raw counts (without batch correction).")
        return errors, warnings, None

    # Sum counts of duplicated genes, using the coerced numeric values.
    data = numeric.groupby(numeric.index).sum()

    # Process data to return: samples in rows, required genes in columns.
    data = data.T
    data = data.reindex(columns=genes, fill_value=0)

    return errors, warnings, data


# Validate and reshape the loaded table; `e` and `w` receive the error and
# warning lists.
# NOTE: `data` is rebound from the raw table to the processed result, so
# re-running this cell without first re-running the load cell passes the
# already-processed frame back into pyCheckData.
e, w, data = pyCheckData(data)

display(data)

feature_name      AAAS  AADAT   AAGAB    AAK1    AAR2  AARSD1  AASDHPPT  \
JAN_P            543.0  160.0   518.0  4397.0   458.0   429.0     706.0   
AAB_P            704.0  203.0   785.0  4871.0   299.0   397.0    1078.0   
AAC_P            471.0  117.0   431.0  3418.0   421.0   360.0     788.0   
AAF_P            517.0  220.0   289.0  3521.0   206.0   326.0     555.0   
AAG_P            655.0  259.0   639.0  5450.0   305.0   320.0     747.0   
...                ...    ...     ...     ...     ...     ...       ...   
GLSS-SM-R088_P  2696.0  270.0  1085.0   273.0   705.0   680.0    1116.0   
GLSS-SM-R093_R  1587.0  511.0  1655.0   651.0  1571.0  1209.0    2911.0   
GLSS-SM-R093_P  2020.0  782.0  2248.0   644.0  1673.0  1132.0    2502.0   
TCGA-06-0125_R  1084.0  275.0  1569.0   812.0  1535.0   859.0    2860.0   
TCGA-06-0125_P  2602.0  918.0  2147.0   427.0  3084.0  1477.0    3777.0   

feature_name      AASS   AATK   ABCA5  ...  ZSCAN21  ZSCAN5A  ZSCAN9  ZSWIM4  \
JAN_P           154

feature_name,AAAS,AADAT,AAGAB,AAK1,AAR2,AARSD1,AASDHPPT,AASS,AATK,ABCA5,...,ZSCAN21,ZSCAN5A,ZSCAN9,ZSWIM4,ZSWIM6,ZW10,ZWILCH,ZWINT,ZXDC,ZYX
JAN_P,543.0,160.0,518.0,4397.0,458.0,429.0,706.0,1541.0,717.0,1540.0,...,380.0,389.0,381.0,524.0,1222.0,378.0,75.0,237.0,2443.0,3112.0
AAB_P,704.0,203.0,785.0,4871.0,299.0,397.0,1078.0,3087.0,284.0,331.0,...,384.0,394.0,403.0,306.0,1102.0,270.0,185.0,268.0,1981.0,4670.0
AAC_P,471.0,117.0,431.0,3418.0,421.0,360.0,788.0,1421.0,68.0,510.0,...,351.0,220.0,244.0,239.0,1255.0,334.0,102.0,324.0,2225.0,2421.0
AAF_P,517.0,220.0,289.0,3521.0,206.0,326.0,555.0,1837.0,104.0,1460.0,...,368.0,285.0,387.0,476.0,886.0,262.0,62.0,177.0,1100.0,2649.0
AAG_P,655.0,259.0,639.0,5450.0,305.0,320.0,747.0,874.0,126.0,526.0,...,159.0,97.0,331.0,263.0,1164.0,318.0,133.0,167.0,925.0,3495.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GLSS-SM-R088_P,2696.0,270.0,1085.0,273.0,705.0,680.0,1116.0,1273.0,977.0,558.0,...,517.0,127.0,318.0,457.0,799.0,664.0,673.0,1799.0,630.0,8508.0
GLSS-SM-R093_R,1587.0,511.0,1655.0,651.0,1571.0,1209.0,2911.0,569.0,456.0,1004.0,...,360.0,201.0,451.0,236.0,1201.0,539.0,563.0,1147.0,543.0,1919.0
GLSS-SM-R093_P,2020.0,782.0,2248.0,644.0,1673.0,1132.0,2502.0,1099.0,61.0,1286.0,...,593.0,371.0,830.0,236.0,1433.0,606.0,741.0,1136.0,1164.0,3447.0
TCGA-06-0125_R,1084.0,275.0,1569.0,812.0,1535.0,859.0,2860.0,900.0,723.0,870.0,...,611.0,332.0,404.0,511.0,1272.0,744.0,328.0,180.0,1053.0,11983.0


In [20]:
def GBMPurity(data):
    """
    Estimate tumour purity for every sample with the saved GBMPurity model.

    Parameters:
    data (pd.DataFrame): Samples x genes raw counts. Columns are matched to
        the rows of ../data/GBMPurity_genes.csv by position, so they must be
        in that order (pyCheckData returns frames laid out this way).

    Returns:
    pd.DataFrame: One row per sample with columns 'Sample' (taken from
        data.index) and 'Purity' (model output clipped to [0, 1] and rounded
        to 3 decimals).
    """

    # Import gene lengths
    gene_lengths = pd.read_csv("../data/GBMPurity_genes.csv")
    lengths = gene_lengths['feature_length'].values

    # Transform input data
    X = np.log2(tpm(data.values, lengths) + 1)

    # Import model
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code - only load a checkpoint from a trusted source.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # fully pickled modules; if loading fails there, pass weights_only=False.
    # map_location keeps the model on CPU so it matches the CPU input tensor and
    # the .numpy() call below, wherever the checkpoint was saved.
    model = torch.load("./model/GBMPurity.pt", map_location="cpu")
    model.eval()

    # Input to GBMPurity; gradients are not needed for inference.
    with torch.no_grad():
        y_pred = model(torch.tensor(X).float()).numpy().flatten().clip(0, 1)

    samples = data.index.values
    results = pd.DataFrame({'Sample': samples, 'Purity': y_pred.round(3)})
    return results


# Score the validated samples; the returned DataFrame is the cell output.
GBMPurity(data)

Unnamed: 0,Sample,Purity
0,JAN_P,0.544
1,AAB_P,0.609
2,AAC_P,0.823
3,AAF_P,0.705
4,AAG_P,0.208
...,...,...
603,GLSS-SM-R088_P,0.456
604,GLSS-SM-R093_R,0.317
605,GLSS-SM-R093_P,0.670
606,TCGA-06-0125_R,0.719


In [18]:
# Unit testing
# Small hand-written tables, one per validation scenario.

# Well-formed table: five genes, two samples.
valid_data = pd.DataFrame(dict(
    GeneName=['AAAS', 'AADAT', 'AAGAB', 'AAK1', 'AAR2'],
    Sample1=[543, 160, 518, 4397, 458],
    Sample2=[704, 203, 785, 4871, 299],
))

# Header only: the columns exist but there are no gene rows.
no_genes_data = pd.DataFrame(dict(
    GeneName=[],
    Sample1=[],
    Sample2=[],
))

# Gene names only: no sample columns.
no_samples_data = pd.DataFrame(dict(
    GeneName=['AAAS', 'AADAT', 'AAGAB', 'AAK1', 'AAR2'],
))

# 'AADAT' is listed twice.
duplicate_genes_data = pd.DataFrame(dict(
    GeneName=['AAAS', 'AADAT', 'AADAT', 'AAGAB', 'AAK1'],
    Sample1=[543, 160, 160, 518, 4397],
    Sample2=[704, 203, 203, 785, 4871],
))

# Placeholder gene names only.
no_required_genes_data = pd.DataFrame(dict(
    GeneName=['UNKNOWN1', 'UNKNOWN2', 'UNKNOWN3'],
    Sample1=[543, 160, 518],
    Sample2=[704, 203, 785],
))

# One missing cell in each sample.
missing_values_data = pd.DataFrame(dict(
    GeneName=['AAAS', 'AADAT', 'AAGAB', 'AAK1', 'AAR2'],
    Sample1=[543, 160, None, 4397, 458],
    Sample2=[704, None, 785, 4871, 299],
))

# One text cell in each sample.
non_numeric_data = pd.DataFrame(dict(
    GeneName=['AAAS', 'AADAT', 'AAGAB', 'AAK1', 'AAR2'],
    Sample1=[543, 'one-sixty', 518, 4397, 458],
    Sample2=[704, 203, 'seven-eighty-five', 4871, 299],
))

# One negative count in each sample.
negative_values_data = pd.DataFrame(dict(
    GeneName=['AAAS', 'AADAT', 'AAGAB', 'AAK1', 'AAR2'],
    Sample1=[543, -160, 518, 4397, 458],
    Sample2=[704, 203, 785, -4871, 299],
))


In [19]:
# Run the validator on one fixture (here: gene names but no sample columns)
# and print the error list followed by the warning list.
e, w, test_data = pyCheckData(no_samples_data)
print(e, w, sep="\n")

# Only score the samples when validation reported no errors.
if not e:
    purities = GBMPurity(test_data)
    display(purities)

["We didn't detect any samples."]
[]


In [None]:
# Test Data Preparation
# Writes one CSV fixture per pyCheckData scenario into ../data/test/.
test_dir = Path("../data/test")
test_dir.mkdir(parents=True, exist_ok=True)  # to_csv does not create missing folders

# Read the reference input once; every scenario below works on its own copy.
reference_df = pd.read_csv("../data/GBMPurity-test-input.csv")

# Header only, no gene rows.
no_genes_df = pd.DataFrame(columns=["GeneName", "Sample1", "Sample2"])
no_genes_df.to_csv(test_dir / "no_genes.csv", index=False)

# Gene names only, no sample columns.
no_samples_df = pd.DataFrame({"GeneName": ["Gene1", "Gene2"]})
no_samples_df.to_csv(test_dir / "no_samples.csv", index=False)

# One missing cell (first gene, first sample).
missing_values_df = reference_df.copy()
missing_values_df.iloc[0, 1] = np.nan
missing_values_df.to_csv(test_dir / "missing_values.csv", index=False)

# First gene repeated at the end. DataFrame.append was removed in pandas 2.0;
# pd.concat with a one-row frame is the replacement.
duplicate_genes_df = pd.concat([reference_df, reference_df.iloc[[0]]], ignore_index=True)
duplicate_genes_df.to_csv(test_dir / "duplicate_genes.csv", index=False)

# Every gene renamed to a placeholder. The replacement list must have one entry
# per row, otherwise the column assignment raises a length-mismatch error.
no_required_genes_df = reference_df.copy()
no_required_genes_df["GeneName"] = [f"NonHGNC{i}" for i in range(1, len(no_required_genes_df) + 1)]
no_required_genes_df.to_csv(test_dir / "no_required_genes.csv", index=False)

# One text cell. The first sample column is cast to object first so that the
# string can be stored without a dtype conflict.
non_numeric_df = reference_df.astype({reference_df.columns[1]: object})
non_numeric_df.iloc[0, 1] = "non-numeric"
non_numeric_df.to_csv(test_dir / "non_numeric_values.csv", index=False)

# One negative count.
negative_values_df = reference_df.copy()
negative_values_df.iloc[0, 1] = -1
negative_values_df.to_csv(test_dir / "negative_values.csv", index=False)

# Keep only the first 79% of the required genes (below the 80% error threshold
# used in pyCheckData). The gene list lives next to the other data files.
gene_lengths = pd.read_csv("../data/GBMPurity_genes.csv")
required_genes = gene_lengths['feature_name']
less_than_80_percent_genes = required_genes[:int(0.79 * len(required_genes))]
less_than_80_percent_genes_df = reference_df[reference_df['GeneName'].isin(less_than_80_percent_genes)]
less_than_80_percent_genes_df.to_csv(test_dir / "less_than_80_percent_genes.csv", index=False)

# Keep the first 95% of the required genes (between the 80% error threshold and
# the 99% warning threshold used in pyCheckData).
between_80_and_99_percent_genes = required_genes[:int(0.95 * len(required_genes))]
between_80_and_99_percent_genes_df = reference_df[reference_df['GeneName'].isin(between_80_and_99_percent_genes)]
between_80_and_99_percent_genes_df.to_csv(test_dir / "between_80_and_99_percent_genes.csv", index=False)
