# UMAP visualization

**TODO:** add `plotly` to the `environment.yaml` file so that the tutor can run this notebook.

## Setup

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import seaborn as sb

%matplotlib inline

In [3]:
# Root folder of the raw inputs (relative to the notebook's working directory).
input_dir = "/".join(("data", "raw"))

In [6]:
# Load the UMAP coordinates (one row per sample; integer-named columns 1, 2, 3
# because header=None and column 0 becomes the index).
# NOTE(review): skiprows=9 assumes the scikit-bio ordination text layout
# (Eigvals / Proportion explained / Species / Site header) with one line of
# eigenvalues — confirm if the export settings change.
umap_braycurtis = (
    pd.read_csv(
        f"{input_dir}/umap_pcoa_export/umap_braycurtis/ordination.txt",
        sep='\t',
        index_col=0,
        skiprows=9,
        header=None,
    )
    # read_csv does not stop at the end of the "Site" section, so the trailing
    # "Biplot" / "Site constraints" section headers would otherwise be parsed
    # as bogus sample rows. errors="ignore" keeps this a no-op when absent.
    .drop(index=["Biplot", "Site constraints"], errors="ignore")
)

In [5]:
# Sample metadata: tab-separated, header on the first line, first column used
# as the row index (the merges join on this index).
metadata = pd.read_csv(f"{input_dir}/metadata.tsv", sep='\t', header=0, index_col=0)

## No filtering

In [20]:
# Attach metadata to every UMAP point (left join: all ordination rows are kept),
# then make infant_id categorical, presumably so plotly uses a discrete palette.
umap_braycurtis_metadata_merged = umap_braycurtis.merge(
    metadata, how='left', left_index=True, right_index=True
).assign(infant_id=lambda merged: merged["infant_id"].astype("category"))

### Infant clustering

In [40]:
# 3-D UMAP embedding of all samples, one colour per infant.
fig = px.scatter_3d(
    umap_braycurtis_metadata_merged,
    x=1, y=2, z=3,  # the three integer-named ordination columns
    color='infant_id',
    opacity=0.7,
)
fig.update_layout(margin={"l": 0, "r": 0, "b": 0, "t": 0})  # no whitespace around the scene
fig.show()

### Timepoint clustering

In [41]:
# All samples, coloured by sampling timepoint.
fig = px.scatter_3d(
    umap_braycurtis_metadata_merged,
    x=1,
    y=2,
    z=3,
    color='timepoint',
    opacity=0.7,
).update_layout(margin=dict(t=0, b=0, l=0, r=0))  # update_layout returns the figure
fig.show()

### Infant + Timepoint clustering

In [42]:
# Colour identifies the infant, marker shape identifies the timepoint.
timepoint_symbols = {
    "2 months": "circle",
    "4 months": "cross",
    "6 months": "diamond",
}
fig = px.scatter_3d(
    umap_braycurtis_metadata_merged,
    x=1, y=2, z=3,
    color='infant_id',
    symbol='timepoint',
    symbol_map=timepoint_symbols,
    opacity=0.7,
)
fig.update_layout(margin={"l": 0, "r": 0, "b": 0, "t": 0})
fig.show()

## Infants that have samples at 2, 4 and 6 months

### Filtering

In [16]:
# Keep only infants sampled at every one of the 2, 4 and 6 month timepoints.
# Checking membership explicitly (instead of "number of distinct timepoints
# == 3") stays correct if the metadata ever contains other timepoint labels.
required_timepoints = {"2 months", "4 months", "6 months"}
timepoints_by_infant = metadata.groupby("infant_id")["timepoint"].unique()
infants_246 = timepoints_by_infant.map(required_timepoints.issubset)  # bool per infant
infants_246_index = infants_246[infants_246].index
metadata_246 = metadata[metadata["infant_id"].isin(infants_246_index)]

print("Number of infant in original metadata: ", metadata["infant_id"].nunique())
print("Number of infants left after filtering: ", metadata_246["infant_id"].nunique())

# Samples per infant (rows) and timepoint (columns) after filtering.
print("Number of samples per infant per timepoint: ")
metadata_246.groupby(["infant_id", "timepoint"]).size().unstack()

Number of infant in original metadata:  17
Number of infants left after filtering:  5
Number of samples per infant per timepoint: 


timepoint,2 months,4 months,6 months
infant_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
3,4,2,2
4,10,4,3
5,7,3,3
7,4,2,4
10,5,1,1


In [17]:
# Inner join: only UMAP points whose sample survived the 2/4/6-month filter.
umap_braycurtis_metadata_246_merged = umap_braycurtis.merge(
    metadata_246, how='inner', left_index=True, right_index=True
).assign(infant_id=lambda merged: merged["infant_id"].astype("category"))


### Infant clustering

In [38]:
# Infants with samples at all three timepoints, one colour per infant.
fig = px.scatter_3d(
    umap_braycurtis_metadata_246_merged,
    x=1, y=2, z=3,
    color='infant_id',
    opacity=0.7,
)
fig.update_layout(margin={"l": 0, "r": 0, "b": 0, "t": 0})
fig.show()

### Timepoint clustering

In [37]:
# Same subset, coloured by timepoint.
fig = px.scatter_3d(
    umap_braycurtis_metadata_246_merged,
    x=1,
    y=2,
    z=3,
    color='timepoint',
    opacity=0.7,
).update_layout(margin=dict(t=0, b=0, l=0, r=0))
fig.show()

### Infant + timepoint clustering

In [61]:
# Colour = infant, marker shape = timepoint.
timepoint_symbols = {
    "2 months": "circle",
    "4 months": "cross",
    "6 months": "diamond",
}
fig = px.scatter_3d(
    umap_braycurtis_metadata_246_merged,
    x=1, y=2, z=3,
    color='infant_id',
    symbol='timepoint',
    symbol_map=timepoint_symbols,
    opacity=0.7,
)
fig.update_layout(margin={"l": 0, "r": 0, "b": 0, "t": 0})
fig.show()

### Sleep Quality Scores

In [58]:
# Infants with samples at all three timepoints: colour = sleep_quality score,
# marker shape = infant (the commented-out timepoint symbol_map leftover from
# the previous cell has been removed).
fig = px.scatter_3d(
    umap_braycurtis_metadata_246_merged,
    x=1,
    y=2,
    z=3,
    color='sleep_quality',
    symbol='infant_id',
    opacity=0.7,
)

fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()

## All infants, at most 3 samples per infant per timepoint

### Filtering

In [47]:
# Keep at most the first few samples of each infant at each timepoint.
# NOTE(review): assumes sample_number counts samples within each
# infant/timepoint pair starting from 1 — confirm against the metadata.
max_samples_per_timepoint = 3
metadata_samples3 = metadata.loc[metadata["sample_number"] <= max_samples_per_timepoint, :]

# Remaining samples per infant (rows) and timepoint (columns); NaN = none left.
metadata_samples3.groupby(["infant_id", "timepoint"]).size().unstack()

timepoint,2 months,4 months,6 months
infant_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,3.0,,
2,2.0,,3.0
3,3.0,2.0,2.0
4,3.0,3.0,3.0
5,3.0,3.0,3.0
6,2.0,,
7,3.0,2.0,3.0
8,3.0,3.0,
9,3.0,2.0,
10,3.0,1.0,1.0


In [49]:
# Inner join: only UMAP points whose sample passed the sample_number cap.
umap_braycurtis_metadata_samples3_merged = umap_braycurtis.merge(
    metadata_samples3, how='inner', left_index=True, right_index=True
).assign(infant_id=lambda merged: merged["infant_id"].astype("category"))

### Infant clustering

In [50]:
# Capped subset, one colour per infant.
fig = px.scatter_3d(
    umap_braycurtis_metadata_samples3_merged,
    x=1, y=2, z=3,
    color='infant_id',
    opacity=0.7,
)
fig.update_layout(margin={"l": 0, "r": 0, "b": 0, "t": 0})
fig.show()

### Timepoint clustering

In [62]:
# Capped subset, coloured by timepoint. Drawn in 3-D (z = third UMAP axis)
# like every other clustering plot; the previous 2-D px.scatter with a
# commented-out z was a leftover from experimentation.
fig = px.scatter_3d(
    umap_braycurtis_metadata_samples3_merged,
    x=1,
    y=2,
    z=3,
    color='timepoint',
    opacity=0.7,
)

fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()

### Timepoint and infant

In [51]:
# Colour = infant, marker shape = timepoint.
timepoint_symbols = {
    "2 months": "circle",
    "4 months": "cross",
    "6 months": "diamond",
}
fig = px.scatter_3d(
    umap_braycurtis_metadata_samples3_merged,
    x=1, y=2, z=3,
    color='infant_id',
    symbol='timepoint',
    symbol_map=timepoint_symbols,
    opacity=0.7,
)
fig.update_layout(margin={"l": 0, "r": 0, "b": 0, "t": 0})
fig.show()