In [1]:
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy import stats
from colorir import *

In [2]:
# Taxonomic orders, paired positionally with the species counts below.
# "Others" is the catch-all group; its count is derived, not entered by hand.
orders = "Primates, Rodentia, Chiroptera, Artiodactyla, Carnivora, Lagomorpha, Perissodactyla, Pilosa, Macroscelidea, Sirenia, Eulipotyphla, Pholidota, Others"
orders = orders.split(", ")
sp_richness = [518, 2552, 1386, 551, 305, 98, 21, 10, 20, 5, 527, 8]
sp_total = 6111
# Derive "Others" from `sp_total` (rather than repeating the literal) so the
# remainder stays correct if the total is ever updated.
sp_richness.append(sp_total - sum(sp_richness))
# zip() silently drops unmatched entries, so fail loudly on a mismatch.
if len(orders) != len(sp_richness):
    raise ValueError(
        f"{len(orders)} orders but {len(sp_richness)} species counts"
    )
# Order name -> species count, sorted by decreasing count (dicts keep
# insertion order, which later cells rely on for row/colour alignment).
data = dict(sorted(zip(orders, sp_richness), key=lambda t: -t[1]))

In [3]:
# Build the colour list shared by every per-order plot in this notebook.
# NOTE(review): Palette, StackPalette, Grad and hue_sort_key arrive through the
# wildcard `from colorir import *` — confirm their semantics in the colorir docs.
colors = Palette.load()
pal = StackPalette.load("set3")
# sorted() returns a plain list, so the list operations below (+=, append,
# del) operate on a list rather than on the loaded palette object.
pal = sorted(pal, key=hue_sort_key(gray_start=False))
# Add three extra colours derived from strawberry and the second-to-last
# sorted colour (presumably an interpolated gradient; verify Grad.n_colors).
pal += Grad([colors.strawberry, pal[-2]]).n_colors(3)
# Translucent black goes last: later cells use pal[-1] for trend lines and
# pal[:-1] for the per-order marker colours.
pal.append("rgba(0, 0, 0, 0.2)")
# Hand-picked removals. The indices refer to the list exactly as built above,
# and the slice must be deleted before index 8 — do not reorder.
del pal[10:12]
del pal[8]
# Display the final palette as the cell output.
pal

[[48;2;251;128;114m   [0m[38;2;251;128;114m #fb8072[0m,
 [48;2;253;180;98m   [0m[38;2;253;180;98m #fdb462[0m,
 [48;2;255;237;111m   [0m[38;2;255;237;111m #ffed6f[0m,
 [48;2;255;255;179m   [0m[38;2;255;255;179m #ffffb3[0m,
 [48;2;179;222;105m   [0m[38;2;179;222;105m #b3de69[0m,
 [48;2;204;235;197m   [0m[38;2;204;235;197m #ccebc5[0m,
 [48;2;141;211;199m   [0m[38;2;141;211;199m #8dd3c7[0m,
 [48;2;128;177;211m   [0m[38;2;128;177;211m #80b1d3[0m,
 [48;2;188;128;189m   [0m[38;2;188;128;189m #bc80bd[0m,
 [48;2;252;90;141m   [0m[38;2;252;90;141m #fc5a8d[0m,
 [48;2;255;150;185m   [0m[38;2;255;150;185m #ff96b9[0m,
 [48;2;252;205;229m   [0m[38;2;252;205;229m #fccde5[0m,
 'rgba(0, 0, 0, 0.2)']

## Correlation plot with duplications

In [4]:
# Hand-entered average number of duplications per order.
# NOTE(review): several denominators differ from the number of summed terms
# (4 terms / 5, 6 terms / 10, 7 terms / 11) — presumably species with zero
# duplications are omitted from the sums; confirm against the source data.
avg_duplications = {
    "Pilosa": 2,
    "Macroscelidea": 0,
    "Sirenia": 1,
    "Lagomorpha": 1,
    "Rodentia": (2 + 1  + 3 + 7 + 7 + 4 + 6 + 6) / 8,
    "Primates": 1,
    "Eulipotyphla": 1,
    "Chiroptera": (0 + 1 + 2 + 5) / 5,
    "Perissodactyla": 3,
    "Artiodactyla": (1 + 1 + 2 + 5 + 5 + 2) / 10,
    "Pholidota": 1,
    "Carnivora": (1 + 1 + 1 + 1 + 1 + 1 + 1) / 11,
}

# One [species count, average duplications] pair per order, keeping the
# ordering of `data`; the "Others" group is dropped.
corr_data = {order: [n_species] for order, n_species in data.items()}
corr_data.pop("Others")
for order, avg in avg_duplications.items():
    corr_data[order].append(avg)

corr_df = (
    pd.DataFrame.from_dict(
        corr_data,
        orient="index",
        columns=["# of species", "Avg. duplications"],
    )
    .rename_axis("Order")
    .reset_index()
)
# Every palette entry except the trailing trend-line colour, assigned by row.
corr_df["Color"] = pal[:-1]

In [5]:
# Pearson correlation between species count and average duplications across
# orders; the displayed pair is (correlation coefficient, p-value).
stats.pearsonr(corr_df["# of species"], corr_df["Avg. duplications"])

(0.6825426847643353, 0.014453410595422392)

In [6]:
# Ordinary least squares: average duplications ~ intercept + species count.
X = sm.add_constant(corr_df["# of species"].values)
model = sm.OLS(endog=corr_df["Avg. duplications"].values, exog=X)
est = model.fit()
# 95% confidence intervals of the fitted coefficients.
out = est.conf_int(alpha=0.05)
# Predictions at the observed x values; the mean_ci_* columns feed the shaded
# band drawn in the next cell.
pred = est.get_prediction(X).summary_frame()

In [7]:
# Closed polygon around the confidence band of the fitted mean: walk the upper
# bound left-to-right in row order, then the lower bound back again.
x = list(corr_df["# of species"].values)
y_l = list(pred["mean_ci_lower"].values)
y_u = list(pred["mean_ci_upper"].values)
error_trace = go.Scatter(
    x=[*x, *reversed(x)],
    y=[*y_u, *reversed(y_l)],
    fill="toself",
    fillcolor="rgba(0, 0, 0, 0.1)",
    # Fully transparent outline: only the fill is visible.
    line_color="rgba(0, 0, 0, 0)",
    hoverinfo="skip",
)

In [8]:
# Scatter of average duplications against species count, one marker per
# order, coloured by the literal value in the "Color" column ("identity"),
# with a single OLS trend line over all points.
fig = px.scatter(
    data_frame=corr_df,
    x="# of species",
    y="Avg. duplications",
    color="Color",
    color_discrete_map="identity",
    hover_data=["Order"],
    trendline="ols",
    trendline_scope="overall",
    trendline_color_override=pal[-1],
)
fig.update_traces(marker=dict(size=12, line=dict(width=0.5)))
# Traces for the combined figure: the confidence band first, then the px traces.
scat_traces = (error_trace, *fig.data)

## ERVs in orders

In [9]:
# Hand-entered ERV values, one list per row of corr_df in its existing row
# order (orders by decreasing species count); None marks an order with no
# values.
corr_df["Raw ERVs"] = [
    [822, 2072, 203, 4387, 1141, 757],
    [267, 212],
    [783, 539, 78, 189, 26],
    [2115, 274],
    [1608, 834, 958, 817, 601, 323, 939, 625, 612, 359, 209, 246, 493],
    [198, 47, 103, 258],
    [204, 289],
    [191, 222],
    None,
    [51],
    None,
    [663]
]
# Each order is represented by its largest value; orders without data stay missing.
corr_df["ERVs"] = [max(erv) if erv is not None else None for erv in corr_df["Raw ERVs"]]
# Filter and re-index in one chained expression instead of calling
# reset_index(inplace=True) on the filtered slice: this yields the same frame
# (old row labels kept in an "index" column) without mutating a slice of
# corr_df in place.
erv_df = corr_df[corr_df["ERVs"].notna()].reset_index()

In [10]:
# Pearson correlation between species count and the per-order ERV value, over
# the orders that have ERV data; the displayed pair is (coefficient, p-value).
stats.pearsonr(erv_df["# of species"], erv_df["ERVs"])

(0.7908776934993673, 0.006444960520132772)

In [11]:
# Ordinary least squares: ERVs ~ intercept + species count.
# This rebinds X / model / est / out / pred from the duplications fit above.
X = sm.add_constant(erv_df["# of species"].values)
model = sm.OLS(endog=erv_df["ERVs"].values, exog=X)
est = model.fit()
# 95% confidence intervals of the fitted coefficients.
out = est.conf_int(alpha=0.05)
# Predictions at the observed x values; the mean_ci_* columns feed the shaded
# band drawn in the next cell.
pred = est.get_prediction(X).summary_frame()

In [12]:
# Closed polygon around the confidence band of the ERV fit: upper bound in
# row order, then the lower bound in reverse.
x = list(erv_df["# of species"].values)
y_l = list(pred["mean_ci_lower"].values)
y_u = list(pred["mean_ci_upper"].values)
error_trace = go.Scatter(
    x=[*x, *reversed(x)],
    y=[*y_u, *reversed(y_l)],
    fill="toself",
    fillcolor="rgba(0, 0, 0, 0.1)",
    # Fully transparent outline: only the fill is visible.
    line_color="rgba(0, 0, 0, 0)",
    hoverinfo="skip",
)

In [13]:
# --- ERV scatter -----------------------------------------------------------
erv_traces = (error_trace,)
erv_fig = px.scatter(
    erv_df,
    x="# of species",
    y="ERVs",
    color="Color",
    # Use the literal colour stored in each row, exactly like the duplications
    # scatter. The previous color_discrete_sequence=erv_df["Color"] only gave
    # the right colours while every colour was unique, listed in row order,
    # and the frame had a 0..n-1 index.
    color_discrete_map="identity",
    trendline="ols",
    trendline_scope="overall",
    trendline_color_override=pal[-1],
    hover_data=["Order"],
    template="plotly_white"
)
erv_fig.update_traces(marker_size=12, marker_line_width=0.5)
erv_traces += erv_fig.data

# --- Combined figure: ERVs on top, duplications below, shared x axis --------
fig = make_subplots(2, 1, shared_xaxes=True, x_title="# of species", vertical_spacing=0.03)
for trace in erv_traces:
    fig.add_trace(trace, row=1, col=1)
for trace in scat_traces:
    fig.add_trace(trace, row=2, col=1)

fig.update_layout(
    showlegend=False,
    width=800,
    height=800,
    template="plotly_white"
)
fig.update_yaxes(title="ERVs", row=1, col=1)
fig.update_yaxes(title="Avg. # of duplications", row=2, col=1)
fig.show("iframe")
# Writes the PDF next to the notebook (relative path).
fig.write_image("diversity_corr.pdf")

## Waffle chart

In [14]:
# Waffle grid: `st` squares in total, filled as a flat array and reshaped to
# (w, h) once it has been validated.
w = 21
h = 9
st = w * h
m = np.zeros(st, dtype=int)


def _squares_for(n_species):
    """Return the (rounded) number of grid squares covering `n_species` species."""
    return round(st * n_species / sp_total)


new_data = {}
other_l, other_v = [], 0
# Filter small groups: anything rounding to zero squares is pooled into one
# merged group labelled with the first three letters of each member.
for k, v in data.items():
    if _squares_for(v) < 1:
        other_l.append(k[:3])
        other_v += v
    else:
        new_data[k] = v
merged_key = " + ".join(other_l)
# Only add the merged group when something was actually pooled; previously an
# empty "" group with 0 species was added when no group was small.
if other_l:
    new_data[merged_key] = other_v


def sort_key(k):
    """Order groups by decreasing size, then the merged group, then "Others".

    The merged group is recognised by its exact key rather than by a "+" in
    the name, so a merged group with a single member is placed correctly too.
    """
    if k == "Others":
        return 2
    if k == merged_key:
        return 1
    return -new_data[k]


new_data = {k: new_data[k] for k in sorted(new_data, key=sort_key)}

# Fill the flat grid group by group. Codes run from len(new_data) (first,
# largest group) down to 1 (last group), so np.max(m) equals the group count.
i = 0
for j, (k, v) in zip(range(len(new_data), 0, -1), new_data.items()):
    sqs = _squares_for(v)
    m[i:i + sqs] = j
    print(j, k, '\t', sqs)
    i += sqs
# Gotta find a width and height balance such that sum of rounded numbers == st
# Checked before reshaping: slice assignment silently truncates an overflow.
if i != st:
    raise ValueError(f"i == {i} != st ({st})")
m = m.reshape((w, h))

11 Rodentia 	 79
10 Chiroptera 	 43
9 Artiodactyla 	 17
8 Eulipotyphla 	 16
7 Primates 	 16
6 Carnivora 	 9
5 Lagomorpha 	 3
4 Perissodactyla 	 1
3 Macroscelidea 	 1
2 Pil + Pho + Sir 	 1
1 Others 	 3


In [15]:
# Number of categories = highest code written into the grid.
color_cats = m.max()
# Colours for the waffle chart, listed from the first (largest) group down to
# the last; reversed below because the last group has the lowest code.
w_pal = pal[:9] + [pal[-3], pal[-1]]
# Stepped (discrete) colorscale: each colour spans one band of width
# 1 / color_cats, given as a (start, colour), (end, colour) pair.
colorscale = [
    (edge / color_cats, shade)
    for band, shade in enumerate(w_pal[::-1])
    for edge in (band, band + 1)
]

In [16]:
# 2 * color_cats + 1 evenly spaced tick positions over the code range
# [1, color_cats]; only the odd positions get a label, the rest stay empty.
ticks = np.linspace(1, color_cats, 2 * color_cats + 1)
# Reversed so that labels[0] belongs to code 1 (the last group).
labels = list(new_data)[::-1]
ticktexts = [labels[pos // 2] if pos % 2 else "" for pos in range(2 * color_cats + 1)]
colorbar = go.heatmap.ColorBar(
    title="Order",
    lenmode="pixels",
    len=15 * (color_cats + 2),
    thickness=15,
    tickmode="array",
    tickvals=ticks,
    ticktext=ticktexts,
    tickfont=dict(size=8),
    yanchor="top",
    y=1,
    ypad=0,
)
# The waffle itself: one heatmap cell per grid square, separated by gaps.
heatmap = go.Heatmap(
    z=m,
    colorscale=colorscale,
    colorbar=colorbar,
    xgap=3,
    ygap=3,
)
fig = go.Figure(data=heatmap)
fig.update_layout(
    width=350,
    xaxis_showgrid=False,
    xaxis_showticklabels=False,
    yaxis_showgrid=False,
    yaxis_showticklabels=False,
    # Tie the y scale to x so that cells are drawn square.
    yaxis_scaleanchor="x",
    plot_bgcolor="rgba(0, 0, 0, 0)",
)

def add_lines_y(ranges, colors, texts, figure=None):
    """Draw labelled vertical bracket lines at x = -0.8, one per range segment.

    Args:
        ranges: Sequence of y boundaries; each consecutive pair (a, b) yields
            a line from y = a - 0.3 to y = b - 0.7.
        colors: Line colours, one per segment. Extra entries are ignored.
        texts: Annotation texts, one per segment, placed at the middle of the
            line and right-aligned against it.
        figure: Object providing ``add_shape`` / ``add_annotation``. Defaults
            to the notebook-level ``fig``, looked up when the function is
            called (the original behaviour).

    The number of segments drawn is limited by the shortest of the three
    inputs, because they are zipped together.
    """
    target = fig if figure is None else figure
    for (start, end), color, text in zip(zip(ranges, ranges[1:]), colors, texts):
        pos = dict(x0=-0.8, y0=start - 0.3, x1=-0.8, y1=end - 0.7)
        target.add_shape(
            type="line",
            line=dict(color=color, width=1),
            **pos
        )
        target.add_annotation(
            text=text,
            xanchor="right",
            x=pos["x0"],
            y=(pos["y0"] + pos["y1"]) / 2,
            font_size=8,
            align="right",
            showarrow=False,
            xshift=-4
        )
        
def add_lines_x(ranges, colors, texts, figure=None, n_rows=None):
    """Draw labelled horizontal bracket lines just past the last grid row.

    Args:
        ranges: Sequence of x boundaries; each consecutive pair (a, b) yields
            a line from x = a - 0.3 to x = b - 0.7.
        colors: Line colours, one per segment. Extra entries are ignored.
        texts: Annotation texts, one per segment, centred on the line and
            anchored at its bottom edge.
        figure: Object providing ``add_shape`` / ``add_annotation``. Defaults
            to the notebook-level ``fig``, looked up when the function is
            called (the original behaviour).
        n_rows: Number of grid rows; the lines sit at y = n_rows - 0.2.
            Defaults to ``len(m)`` of the notebook-level grid ``m``.

    The number of segments drawn is limited by the shortest of the three
    sequence inputs, because they are zipped together.
    """
    target = fig if figure is None else figure
    rows = len(m) if n_rows is None else n_rows
    for (start, end), color, text in zip(zip(ranges, ranges[1:]), colors, texts):
        pos = dict(x0=start - 0.3, y0=rows - 0.2, x1=end - 0.7, y1=rows - 0.2)
        target.add_shape(
            type="line",
            line=dict(color=color, width=1),
            **pos
        )
        target.add_annotation(
            text=text,
            yanchor="bottom",
            x=(pos["x0"] + pos["x1"]) / 2,
            y=pos["y0"],
            font_size=8,
            align="right",
            showarrow=False
        )

# Optional: uncomment the two calls below to add bracket lines to the chart,
# each labelled with the corresponding value from new_data.
# add_lines_y([0, 9, 14, 16, 18, 19, 20], pal[::-1][1:], list(new_data.values())[:6])
# add_lines_x([0, 3, 4, 5, 6, 9], pal[:5][::-1], list(new_data.values())[6:])
# Render the waffle chart with plotly's "iframe" renderer.
fig.show("iframe")

In [17]:
# Export the waffle chart built above to a PDF next to the notebook (relative path).
fig.write_image("waffle_chart.pdf")

## ERV vs. TRIM5 duplication correlation

In [55]:
# Load the per-species table (sorted by its "ERVs" column) and the cluster
# table, keeping only cluster rows whose TreeStatus is "INCLUDED".
# NOTE(review): column meanings are inferred from their names — confirm
# against the CSV files.
df = pd.read_csv("../ervs_trim5.csv", sep=';', index_col=0)
df.sort_values("ERVs", inplace=True)
dfc = pd.read_csv("../cluster_seqs.csv", sep=';', index_col=0)
dfc = dfc[dfc["TreeStatus"] == "INCLUDED"]
# "TRIM count" = number of INCLUDED cluster rows whose "Species" equals the
# row's index label.
for sp in df.index:
    df.loc[sp, "TRIM count"] = sum(dfc["Species"] == sp)
    
# NOTE(review): this REPLACES the frame loaded above with six hard-coded rows
# (ERVs = 0..5, TRIM count = n**3, TRIM5 count = 2n); only the first six
# "Order" values survive from the CSV, and the "TRIM count" computed above is
# discarded. This looks like leftover placeholder/debug data — confirm before
# trusting the regression and the exported figure below.
df = pd.DataFrame({
    "ERVs": [0, 1, 2, 3, 4, 5],
    "TRIM count": [0, 1, 8, 27, 64, 125],
    "TRIM5 count": [0, 2, 4, 6, 8, 10],
    "Order": df["Order"][:6]
})

# Map each order except "Others" to a palette colour by position (both follow
# the ordering of `data`), then tag every row with its order's colour.
data_ = dict(data)
del data_["Others"]
pal_map = dict(zip(data_.keys(), pal))
for sp in df.index:
    df.loc[sp, "Color"] = pal_map[df.loc[sp, "Order"]]

In [56]:
# Two stacked panels sharing the x axis: one Poisson regression per count column.
fig = make_subplots(
    rows=2,
    cols=1,
    shared_xaxes=True,
    x_title="% of ERVs in the genome",
    vertical_spacing=0.03,
)

for row, df_col in enumerate(["TRIM count", "TRIM5 count"], start=1):
    # Poisson GLM: count ~ intercept + ERVs.
    X = sm.add_constant(df["ERVs"].values)
    y = df[df_col].values
    model = sm.GLM(y, X, family=sm.families.Poisson()).fit()
    pred = model.get_prediction(X).summary_frame()
    print(model.summary())

    # Observed points, coloured per row by the "Color" column.
    scatter = go.Scatter(
        x=df["ERVs"],
        y=df[df_col],
        mode="markers",
        marker=dict(color=df["Color"], size=8, line=dict(width=0.5)),
        hovertext=df.index,
    )

    # Fitted mean.
    regr = go.Scatter(
        x=df["ERVs"],
        y=pred["mean"],
        mode="lines",
        line=dict(color="rgba(0, 0, 0, 0.2)"),
    )

    # Confidence band of the fitted mean as a closed polygon: upper bound in
    # row order, then the lower bound in reverse.
    x = list(df["ERVs"])
    y_l = list(pred["mean_ci_lower"])
    y_u = list(pred["mean_ci_upper"])
    regr_shadow = go.Scatter(
        x=[*x, *reversed(x)],
        y=[*y_u, *reversed(y_l)],
        fill="toself",
        fillcolor="rgba(0, 0, 0, 0.1)",
        line_color="rgba(0, 0, 0, 0)",
        hoverinfo="skip",
    )

    for trace in (scatter, regr, regr_shadow):
        fig.add_trace(trace, row=row, col=1)

fig.update_layout(
    showlegend=False,
    width=800,
    height=800,
    template="plotly_white",
)
fig.update_yaxes(title_text="TRIM count", row=1, col=1)
fig.update_yaxes(title_text="TRIM5 count", row=2, col=1)
fig.show("iframe")

                 Generalized Linear Model Regression Results                  
Dep. Variable:                      y   No. Observations:                    6
Model:                            GLM   Df Residuals:                        4
Model Family:                 Poisson   Df Model:                            1
Link Function:                    Log   Scale:                          1.0000
Method:                          IRLS   Log-Likelihood:                -16.366
Date:                Tue, 30 Aug 2022   Deviance:                       8.9869
Time:                        15:22:25   Pearson chi2:                     6.97
No. Iterations:                     6   Pseudo R-squ. (CS):              1.000
Covariance Type:            nonrobust                                         
                 coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
const          0.3589      0.301      1.191      0.2

In [20]:
# Export the two-panel regression figure to a PDF next to the notebook (relative path).
fig.write_image("erv_dup_corr.pdf")