In [1]:
import math
import string
import itertools

import numpy as np
from matplotlib import pyplot as plt
import pandas as pd

from data_loading import (load_isoform_and_paralog_y2h_data,
                          load_y2h_isoform_data,
                          load_y2h_paralogs_additional_data,
                          load_paralog_pairs)

In [2]:
# Load the combined isoform/paralog Y2H table, the additional paralog screen and the TF gene pairs.
y2h = load_isoform_and_paralog_y2h_data()
y2h_para = load_y2h_paralogs_additional_data()
# .copy(): later cells add columns to y2h_para; working on an explicit copy avoids
# SettingWithCopyWarning from assigning into a boolean-indexed slice.
y2h_para = y2h_para.loc[y2h_para['at_least_2_isoforms'] & y2h_para['at_least_2_partners'], :].copy()
pairs = load_paralog_pairs()

# could also try with valid clones dataset
# Gene-level filters, computed on the isoform PPI rows only; genes that never appear
# in that category map to NaN (and are excluded by the category test below anyway).
isoform_ppis_by_gene = y2h.loc[y2h['category'] == 'tf_isoform_ppis', :].groupby('ad_gene_symbol')
y2h['at_least_2_isoforms'] = y2h['ad_gene_symbol'].map(isoform_ppis_by_gene['ad_clone_acc'].nunique() >= 2)
y2h['at_least_2_partners'] = y2h['ad_gene_symbol'].map(isoform_ppis_by_gene['db_gene_symbol'].nunique() >= 2)
# .copy() for the same reason as y2h_para: later cells add columns to y2h_iso.
y2h_iso = y2h.loc[(y2h['category'] == 'tf_isoform_ppis') &
                  y2h['at_least_2_isoforms'] &
                  y2h['at_least_2_partners'],
                  :].copy()

In [3]:
# (rows, columns) of the isoform PPI table after the >= 2 isoforms / >= 2 partners gene filters
y2h_iso.shape

(5447, 16)

In [4]:
# (rows, columns) of the additional paralog screen after the same gene filters
y2h_para.shape

(4311, 17)

In [5]:
# (rows, columns) of the TF gene pair table
pairs.shape

(148, 4)

In [6]:
# Unordered TF gene pairs covered by the additional paralog screen: each AD gene is
# paired with every gene in its '|'-separated paired_tf_gene field.
# (y2h_para is already restricted by these two flags when loaded; the filter is
# repeated so this cell is self-contained.)
filtered_pairs = {
    frozenset([ad_gene, paired_gene])
    for ad_gene, paired_field in (y2h_para.loc[y2h_para['at_least_2_isoforms'] &
                                               y2h_para['at_least_2_partners'],
                                               ['ad_gene_symbol', 'paired_tf_gene']]
                                  .drop_duplicates()
                                  .itertuples(index=False, name=None))
    for paired_gene in paired_field.split('|')
}

In [7]:
# True when the (unordered) gene pair was also tested in the additional paralog screen.
pairs['additional_tested'] = [frozenset((gene_a, gene_b)) in filtered_pairs
                              for gene_a, gene_b in zip(pairs['tf_gene_a'], pairs['tf_gene_b'])]

In [8]:
# Flag pairs whose two genes were tested against exactly the same set of DB partners
# in the isoform screen (an equality check, not a subset check).
# Genes absent from y2h_iso map to NaN, which never compares equal.
iso_partners = (y2h_iso
                .groupby('ad_gene_symbol')['db_gene_symbol']
                .apply(set))
pairs['partners_a'], pairs['partners_b'] = (pairs['tf_gene_a'].map(iso_partners),
                                            pairs['tf_gene_b'].map(iso_partners))
pairs['same_tested_partners'] = (pairs['partners_a'] == pairs['partners_b'])

In [9]:
# Save the pairs kept for the paralog comparison: tested in the additional screen,
# or sharing an identical partner set in the isoform screen.
(pairs
 .loc[pairs['additional_tested'] | pairs['same_tested_partners'],
      ['tf_gene_a', 'tf_gene_b', 'is_paralog_pair']]
 .to_csv('../../cache/paralog_pairs_filtered.tsv', sep='\t', index=False))

In [10]:
# paralog vs non-paralog (control) pairs among the kept pairs
pairs.loc[pairs['additional_tested'] | pairs['same_tested_partners'], 'is_paralog_pair'].value_counts()

True     84
False    12
Name: is_paralog_pair, dtype: int64

In [11]:
# AD clones with at least one positive ('1') score in either the isoform or the paralog screen.
non_zero_iso = {clone
                for screen in (y2h_iso, y2h_para)
                for clone in screen.loc[screen['score'] == '1', 'ad_clone_acc']}

In [12]:
# Flag rows whose AD isoform scored positive ('1') at least once in either screen.
y2h_para['non_zero_iso'] = y2h_para['ad_clone_acc'].isin(non_zero_iso)
y2h_iso['non_zero_iso'] = y2h_iso['ad_clone_acc'].isin(non_zero_iso)

In [13]:
# Number of distinct non-zero AD isoforms per TF gene.
n_non_zero_isoforms = (y2h_iso.loc[y2h_iso['non_zero_iso'], :]
                       .groupby('ad_gene_symbol')['ad_clone_acc']
                       .nunique())
# isin() yields a plain bool column: genes with no non-zero isoform are False rather
# than NaN (which .map() produced, leaving an object column for later & / sum).
y2h_iso['at_least_2_non_zero_isoforms'] = y2h_iso['ad_gene_symbol'].isin(
    n_non_zero_isoforms.index[n_non_zero_isoforms >= 2])

In [14]:
# True for (TF gene, partner) combinations where at least one isoform scored positive ('1').
has_positive = (y2h_iso.groupby(['ad_gene_symbol', 'db_gene_symbol'])['score']
                .transform(lambda scores: (scores == '1').any()))
# ...and at least two partners remain for the TF gene after excluding those without a positive.
n_positive_partners = (y2h_iso.loc[has_positive, :]
                       .groupby('ad_gene_symbol')['db_gene_symbol']
                       .nunique())
# isin() keeps the result a plain bool column (no NaN for genes lacking positive partners).
y2h_iso['at_least_1_positive_per_partner'] = (
    has_positive &
    y2h_iso['ad_gene_symbol'].isin(n_positive_partners.index[n_positive_partners >= 2]))

In [15]:
# isoform-table rows passing both the positive-partner filter and the >= 2 non-zero isoforms filter
(y2h_iso['at_least_1_positive_per_partner'] &
 y2h_iso['at_least_2_non_zero_isoforms']).sum()

3993

In [16]:
# 'GENE_PARTNER' keys for isoform-screen gene/partner combinations that pass both filters.
non_zero_gene_ppis = {ad_gene + '_' + db_gene
                      for ad_gene, db_gene in (y2h_iso.loc[y2h_iso['at_least_1_positive_per_partner'] &
                                                           y2h_iso['at_least_2_non_zero_isoforms'],
                                                           ['ad_gene_symbol', 'db_gene_symbol']]
                                               .drop_duplicates()
                                               .itertuples(index=False, name=None))}
# A paralog-screen row matches when its DB partner forms a retained PPI with any of the
# '|'-separated paired TF genes.
y2h_para['matches_non_zero_pair'] = [
    any((paired_gene + '_' + db_gene) in non_zero_gene_ppis for paired_gene in paired_field.split('|'))
    for paired_field, db_gene in zip(y2h_para['paired_tf_gene'], y2h_para['db_gene_symbol'])
]

In [17]:
# paralog-screen rows with a non-zero AD isoform whose partner matches a retained isoform-screen PPI of a paired gene
(y2h_para['matches_non_zero_pair'] & y2h_para['non_zero_iso']).sum()

2533

In [18]:
# preview of the filtered paralog screen with the flags added above
y2h_para.head()

Unnamed: 0,category,ad_orf_id,ad_clone_acc,ad_gene_symbol,db_orf_id,db_gene_symbol,score,standard_batch,retest_pla,retest_pos,in_orfeome_screen,in_focussed_screen,in_hi_union,in_lit_bm,paired_tf_gene,at_least_2_isoforms,at_least_2_partners,non_zero_iso,matches_non_zero_pair
9888,tf_paralog_ppis,100553,TCF12|2/3|07A09,TCF12,1270,OLFM3,0,TFr08,92,A02,False,False,False,False,TCF4,True,True,True,True
9889,tf_paralog_ppis,100383,NR2F2|1/2|09E07,NR2F2,4558,RARA,0,TFr08,92,A03,False,False,False,False,RXRB|RXRA|RXRG,True,True,True,False
9890,tf_paralog_ppis,101051,TP53|2/2|02F03,TP53,6901,MCRS1,0,TFr08,92,A04,False,False,False,False,TP63,True,True,True,True
9891,non_paralog_control,101157,HMG20A|1/2|05B08,HMG20A,7215,NR0B1,0,TFr08,92,A05,False,False,False,False,HNF4A,True,True,True,True
9892,tf_paralog_ppis,100318,NR4A1|3/5|03G02,NR4A1,7215,NR0B1,0,TFr08,92,A06,False,False,False,False,ESRRG|ESRRA,True,True,True,True


In [19]:
# TODO (unfinished): restrict paralog pairs to ... -- draft left commented out, not executed
#y2h_iso.loc[y2h_iso['at_least_1_positive_per_partner'] & y2h_iso]

In [20]:
# total rows across the isoform and paralog tables
len(y2h_iso) + len(y2h_para)

9758

In [21]:
# rows whose AD isoform is non-zero, across both tables
sum(screen['non_zero_iso'].sum() for screen in (y2h_iso, y2h_para))

7470

In [22]:
# Per-TF-gene summary across both screens: distinct AD isoforms and distinct DB partners,
# largest genes first. Stored as gene_sizes (not df) so that re-running this cell never
# clobbers the long-format `df` that the plate-layout cell reads.
columns = ['ad_orf_id', 'ad_clone_acc', 'ad_gene_symbol', 'db_orf_id', 'db_gene_symbol']
gene_sizes = (pd.concat([y2h_iso.loc[:, columns],
                         y2h_para.loc[:, columns]])
              .groupby('ad_gene_symbol')[['ad_clone_acc', 'db_gene_symbol']]
              .nunique()
              .rename(columns={'ad_clone_acc': 'n_isoforms',
                               'db_gene_symbol': 'n_partners'})
              .sort_values(['n_isoforms', 'n_partners'], ascending=False))

In [23]:
# Long-format table of every AD clone x DB partner pair from both screens.
df = pd.concat([y2h_iso.loc[:, columns],
                y2h_para.loc[:, columns]])
# Per-TF-gene counts of distinct isoforms and partners, largest genes first.
gene_sizes = (df.groupby('ad_gene_symbol')[['ad_clone_acc', 'db_gene_symbol']]
              .nunique()
              .rename(columns={'ad_clone_acc': 'n_isoforms',
                               'db_gene_symbol': 'n_partners'})
              .sort_values(['n_isoforms', 'n_partners'], ascending=False))
# Shorten clone accessions, e.g. 'TCF12|2/3|07A09' -> 'TCF12-2'.
df['ad_clone_acc'] = df['ad_clone_acc'].apply(lambda acc: acc.split('|')[0] + '-' + acc.split('|')[1].split('/')[0])
df = df.sort_values(['ad_gene_symbol', 'ad_clone_acc', 'db_gene_symbol'])
# A fully duplicated row (e.g. the same pair present in both screens) would be placed
# on the plates twice; treat it as a data error rather than raising a warning class.
n_duplicate_rows = df.duplicated().sum()
if n_duplicate_rows > 0:
    raise ValueError('Unexpected duplicate rows in table: {}'.format(n_duplicate_rows))

In [24]:
# genes with the most isoforms (ties broken by most partners)
gene_sizes.head()

Unnamed: 0_level_0,n_isoforms,n_partners
ad_gene_symbol,Unnamed: 1_level_1,Unnamed: 2_level_1
TCF4,8,198
ZNF451,7,94
SOX6,6,39
ZBTB44,6,25
GRHL3,6,17


In [25]:
# genes with the fewest isoforms / partners
gene_sizes.tail()

Unnamed: 0_level_0,n_isoforms,n_partners
ad_gene_symbol,Unnamed: 1_level_1,Unnamed: 2_level_1
RARB,2,3
ZFY,2,3
ZNF250,2,3
FOXJ2,2,2
HSF2,2,2


In [26]:
# How many TF genes share each (n_isoforms, n_partners) matrix size.
counts = (gene_sizes
          .groupby(['n_isoforms', 'n_partners'])
          .size()
          .reset_index(name='n_tf_genes'))
counts.head(20)

Unnamed: 0,n_isoforms,n_partners,n_tf_genes
0,2,2,2
1,2,3,5
2,2,4,5
3,2,5,2
4,2,6,2
5,2,7,1
6,2,8,1
7,2,9,7
8,2,10,1
9,2,12,1


In [27]:
# total AD x DB wells over all genes: n_isoforms * n_partners per gene, summed over genes
(counts['n_isoforms'] * counts['n_partners'] * counts['n_tf_genes']).sum()

9758

In [28]:
# rough plate estimate: total AD x DB wells / (7 usable rows x 12 columns); ignores the extra empty-AD row per gene
np.round(counts.prod(axis=1).sum() / (7*12))

116.0

In [29]:
# Both screens stacked, with the original (unshortened) clone accessions.
dat = pd.concat([screen.loc[:, columns] for screen in (y2h_iso, y2h_para)])
# distinct AD clones plus distinct DB genes
sum(dat[col].nunique() for col in ('ad_clone_acc', 'db_gene_symbol'))

1084

In [30]:
# positive ('1') scores across both screens
(y2h_iso['score'] == '1').sum() + (y2h_para['score'] == '1').sum()

2593

In [31]:
# 2600 is presumably the positive count above (2593) rounded; positives per 7 x 12 block of wells
2600 / (7*12)

30.952380952380953

In [32]:
# plate estimate as above, also budgeting one well per distinct AD clone and per distinct DB gene
# (presumably for empty-DB / empty-AD controls -- TODO confirm)
np.round((counts.prod(axis=1).sum() + dat['ad_clone_acc'].nunique() + dat['db_gene_symbol'].nunique()) / (7*12))

129.0

In [33]:
# number of distinct DB partner genes across both screens
dat['db_gene_symbol'].nunique()

776

In [34]:
# number of distinct (TF gene, partner gene) combinations across both screens;
# groupby only creates groups that exist, so the former `size() > 0` test was vacuous
dat.groupby(['ad_gene_symbol', 'db_gene_symbol']).ngroups

2815

In [35]:
# TODO:
# add empty-AD wells for Lit-BM and RRS pairs
# try to reduce the number of plates used
# get a unique empty-well code for every plate

In [36]:
class Plate:
    """Rectangular grid of labelled wells, filled by pasting blocks of labels.

    Each well holds a string label; a well whose label equals ``empty_name``
    is free. Positions are (row_index, column_index), 0-based from the
    top-left corner.
    """

    def __init__(self, n_rows, n_columns, empty_name='empty'):
        self.n_rows = n_rows
        self.n_columns = n_columns
        self.empty_name = empty_name
        # grid[row][column]; a fresh list per row so rows are never aliased
        self.grid = [[empty_name] * n_columns for _ in range(n_rows)]

    @staticmethod
    def well_name(row_index, column_index):
        """Return the well name for a position, e.g. (0, 0) -> 'A01', (7, 11) -> 'H12'."""
        return string.ascii_uppercase[row_index] + str(column_index + 1).zfill(2)

    def add_matrix(self, matrix, pos_top_left, transpose=False):
        """Write a block of labels onto the plate with its top-left at ``pos_top_left``.

        Args:
            matrix: sequence of rows, each a sequence of well labels. Rows may
                differ in length; the longest row sets the block width.
            pos_top_left: (row_index, column_index) of the block's top-left well.
            transpose: if True, swap the block's rows and columns before placing
                it (with ragged rows, columns are truncated as ``zip`` does).

        Raises:
            ValueError: if the block starts at a negative position or extends
                past the plate edge, or if any target well is already occupied.
                All checks happen before writing, so a failed call leaves the
                plate unchanged.
        """
        if transpose:
            matrix = [list(column) for column in zip(*matrix)]
        n_rows = len(matrix)
        n_columns = max((len(row) for row in matrix), default=0)
        row_offset, column_offset = pos_top_left
        if (row_offset < 0 or column_offset < 0
                or n_rows + row_offset > self.n_rows
                or n_columns + column_offset > self.n_columns):
            raise ValueError('Does not fit: {} {} {} \n{}'.format(n_rows,
                                                                  n_columns,
                                                                  pos_top_left,
                                                                  self))
        # Validate every target well first so a conflict cannot leave a half-written block.
        for i, row in enumerate(matrix):
            for j in range(len(row)):
                row_index, column_index = i + row_offset, j + column_offset
                if self.grid[row_index][column_index] != self.empty_name:
                    raise ValueError('Well already occupied: {} {}\n{}'.format(row_index, column_index, self))
        for i, row in enumerate(matrix):
            for j, well in enumerate(row):
                self.grid[i + row_offset][j + column_offset] = well

    def lowest_unoccupied_row(self):
        """Return the index of the first (topmost) completely empty row, or None if there is none."""
        return next((i for i in range(self.n_rows) if self.row_is_empty(i)), None)

    def is_empty(self):
        """Return True if every well on the plate is free."""
        return all(self.row_is_empty(i) for i in range(self.n_rows))

    def row_is_empty(self, row_index):
        """Return True if every well in row ``row_index`` is free."""
        return all(x == self.empty_name for x in self.grid[row_index])

    def empty_wells(self):
        """Return the set of names (e.g. 'A01') of all free wells."""
        return {self.well_name(i, j)
                for i, row in enumerate(self.grid)
                for j, x in enumerate(row)
                if x == self.empty_name}

    def __repr__(self):
        # one line per row, wells separated by '|'
        return '\n'.join('|'.join(row) for row in self.grid)

In [39]:
def solve_plate_layout(genes, row_max=7, column_max=12, plate_rows=8, plate_columns=12):
    """Greedily pack per-gene pair matrices onto plates.

    On each pass over the not-yet-placed genes (in dict order), the first gene
    that fits on the current (last) plate is placed and the pass restarts; if no
    gene fits, a new empty plate is started.

    Placement rules for a gene matrix with ``n_rows`` rows and ``n_col_gene`` columns:

    - ``n_rows > row_max`` and the current plate is empty: the matrix is
      transposed and written in chunks of ``plate_rows`` rows, one chunk per
      plate. These chunks may use every plate row, including rows beyond
      ``row_max``. Prints '<gene> rotated'.
    - ``n_col_gene <= column_max``: placed to the right of the previous block if
      it still fits within ``row_max`` rows, otherwise at column 0 of the first
      completely empty row, if it fits there.
    - ``n_col_gene > column_max`` and the current plate is empty: split into
      column chunks of width ``column_max``, laid out right of / below the
      previous chunk and spilling onto new plates as needed.

    Args:
        genes: dict mapping gene name -> matrix (list of rows of well labels).
            Not modified; a shallow copy is consumed.
        row_max: number of rows available to normal (non-transposed) placements.
        column_max: maximum number of columns used by a placement.
        plate_rows: rows per plate (default 8, the value previously hard-coded).
        plate_columns: columns per plate (default 12, the value previously hard-coded).

    Returns:
        list of Plate objects in the order they were filled.

    Raises:
        ValueError: propagated from Plate.add_matrix when a block cannot be
            written (e.g. a transposed gene wider than the plate).
    """
    plates = [Plate(n_rows=plate_rows, n_columns=plate_columns)]
    # top-left (row, column) where the next block is tried "to the right" of the previous one
    row_count, column_count = 0, 0

    # TODO: sort genes: start with largest number isoforms

    unallocated = genes.copy()
    while len(unallocated) > 0:
        gene_fits = False
        # place the first remaining gene that fits on the current plate
        for tf_gene_name, pair_matrix in unallocated.items():
            n_rows = len(pair_matrix)
            n_col_gene = len(pair_matrix[0])

            # rotate matrix for genes with too many isoforms to fit on a plate
            if n_rows > row_max and plates[-1].is_empty():
                print(tf_gene_name, 'rotated')
                # transposed chunks may fill every plate row, not just the first row_max
                row_max_transpose = plate_rows
                gene_fits = True
                pair_matrix = [[pair_matrix[i][j] for i in range(len(pair_matrix))] for j in range(len(pair_matrix[0]))]
                n_row_gene = len(pair_matrix)
                n_cols = len(pair_matrix[0])
                n_rows_split = [row_max_transpose] * (n_row_gene // row_max_transpose) + ([n_row_gene % row_max_transpose] if n_row_gene % row_max_transpose > 0 else [])
                for i, n_rows in enumerate(n_rows_split):
                    matrix_subset = pair_matrix[i * row_max_transpose:i * row_max_transpose + n_rows]
                    # one chunk per plate
                    if not plates[-1].is_empty():
                        plates.append(Plate(n_rows=plate_rows, n_columns=plate_columns))
                    plates[-1].add_matrix(matrix_subset,
                                          pos_top_left=(0, 0))
                    row_count = 0
                    column_count = n_cols
                break

            if n_col_gene <= column_max:  # fits on one row
                if column_count + n_col_gene <= column_max and row_count + n_rows <= row_max:  # add to the right
                    plates[-1].add_matrix(pair_matrix,
                                          pos_top_left=(row_count, column_count))
                    column_count += n_col_gene
                    gene_fits = True
                    break
                else:
                    row_below = plates[-1].lowest_unoccupied_row()
                    if row_below is not None and (row_below + n_rows <= row_max):  # add below
                        plates[-1].add_matrix(pair_matrix,
                                              pos_top_left=(row_below, 0))
                        column_count = n_col_gene
                        row_count = row_below
                        gene_fits = True
                        break
            elif plates[-1].is_empty():  # gene doesn't fit on a single row
                gene_fits = True
                # split into column chunks of width column_max (last chunk may be narrower)
                for i, n_cols in enumerate([column_max] * (n_col_gene // column_max) + ([n_col_gene % column_max] if n_col_gene % column_max > 0 else [])):
                    matrix_subset = [row[i * column_max:i * column_max + n_cols] for row in pair_matrix]
                    if column_count + n_cols <= column_max and row_count + n_rows <= row_max:  # add to the right
                        plates[-1].add_matrix(matrix_subset,
                                              pos_top_left=(row_count, column_count))
                        column_count += n_cols
                    else:
                        row_below = plates[-1].lowest_unoccupied_row()
                        if row_below is not None and (row_below + n_rows <= row_max):  # add below
                            plates[-1].add_matrix(matrix_subset,
                                                  pos_top_left=(row_below, 0))
                            column_count = n_cols
                            row_count = row_below
                        else:  # start a new plate
                            plates.append(Plate(n_rows=plate_rows, n_columns=plate_columns))
                            plates[-1].add_matrix(matrix_subset,
                                                  pos_top_left=(0, 0))
                            row_count = 0
                            column_count = n_cols
                break
        if gene_fits:
            del unallocated[tf_gene_name]
        else:  # start a new plate
            plates.append(Plate(n_rows=plate_rows, n_columns=plate_columns))
            row_count = 0
            column_count = 0
    return plates


# One AD x DB matrix per TF gene: a row per AD isoform plus a trailing 'empty-AD' row,
# a column per DB partner, each well labelled '<isoform>/<partner>'.
# groupby(sort=False) visits genes in order of first appearance (the order of
# df['ad_gene_symbol'].unique()) and .unique() keeps first-appearance order within a
# gene, so the matrices match the former per-gene boolean-mask construction while
# scanning df once instead of once per isoform.
gene_matrices = {}
for gene, gene_rows in df.groupby('ad_gene_symbol', sort=False):
    isoforms = list(gene_rows['ad_clone_acc'].unique()) + ['empty-AD']
    partners = gene_rows['db_gene_symbol'].unique()
    gene_matrices[gene] = [[isoform + '/' + partner for partner in partners]
                           for isoform in isoforms]
plates = solve_plate_layout(gene_matrices)

# lit-bm and rrs pairs...
litbm_rrs = (y2h.loc[y2h['category'].isin(['rrs_isoforms', 'lit_bm_isoforms']),
                     ['ad_clone_acc', 'db_gene_symbol', 'category']]
             .drop_duplicates()
             .copy())
# Same clone-name shortening as df, e.g. 'TCF12|2/3|07A09' -> 'TCF12-2'.
litbm_rrs['ad_clone_acc'] = litbm_rrs['ad_clone_acc'].apply(lambda x: x.split('|')[0] + '-' + x.split('|')[1].split('/')[0])
seed = 307272992
litbm_rrs['pair'] = litbm_rrs['ad_clone_acc'] + '/' + litbm_rrs['db_gene_symbol']
# Skip Lit-BM / RRS pairs that already appear in some gene matrix.
litbm_rrs['already_tested'] = litbm_rrs['pair'].isin((df['ad_clone_acc'] + '/' + df['db_gene_symbol']).values)
litbm_rrs_pairs = litbm_rrs.loc[~litbm_rrs['already_tested'], 'pair'].to_list()
np.random.seed(seed)
np.random.shuffle(litbm_rrs_pairs)
# TODO: add empty-AD for lit-bm and rrs
# Fill the last row (index 7) of each plate where it is still free: six control wells,
# then six wells driven by a 6-bit pattern cycling through all 64 combinations --
# True wells take the next Lit-BM/RRS pair (while any remain), False wells stay 'empty'.
# The resulting set of empty wells is meant to identify the plate. Plates whose last
# row is already used (e.g. by a transposed gene) get no control row.
for plate, code in zip(plates, itertools.cycle(itertools.product([False, True], repeat=6))):
    if plate.row_is_empty(7):
        plate.add_matrix([['control-' + str(i) for i in range(1, 7)] +
                          [litbm_rrs_pairs.pop() if b and len(litbm_rrs_pairs) > 0 else 'empty' for b in code]],
                         (7, 0))

with open('../../output/plate_arrangement.txt', 'w') as f:
    f.write('\n\n\n'.join([str(p) for p in plates]))


def plates_to_table(plates):
    """Flatten plates into a long table with one row per well.

    Args:
        plates: sequence of objects with a ``grid`` attribute (list of rows of
            well labels), e.g. Plate instances.

    Returns:
        DataFrame with columns plate_id (1-based), plate_position (e.g. 'A01'),
        AD and DB. Labels of the form '<AD>/<DB>' are split at the first '/';
        labels without a '/' (e.g. 'empty', 'control-1') go entirely into AD,
        with DB set to None.
    """
    records = []
    for plate_id, plate in enumerate(plates, start=1):
        for row_index, row in enumerate(plate.grid):
            for column_index, label in enumerate(row):
                ad, separator, db = label.partition('/')
                records.append((plate_id,
                                string.ascii_uppercase[row_index] + str(column_index + 1).zfill(2),
                                ad,
                                db if separator else None))
    return pd.DataFrame(records, columns=['plate_id', 'plate_position', 'AD', 'DB'])


plates_to_table(plates).to_csv('../../output/plate_arrangement.tsv', sep='\t', index=False)

print(len(plates), 'plates')

# Each plate's set of empty wells is meant to identify it; count the plates whose
# set is identical to that of an earlier plate.
n_empty_well_codes = len({frozenset(p.empty_wells()) for p in plates})
n_duplicate_codes = len(plates) - n_empty_well_codes
if n_duplicate_codes == 0:
    print('Success! Each plate has its own unique empty well code')
else:
    print(n_duplicate_codes, 'duplicate empty well codes')

TCF4 rotated
ZNF451 rotated
221 plates
55 duplicates empty well codes
