# Data visualisation
---

Here I want to look at the images and the data that go into the model, both as a sense check and to try out how I will view the model's outputs.

## Setup

### Import libraries

In [None]:
import os
import numpy as np
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd

In [None]:
from dataset_multitask import create_dataset

### Set parameters

In [None]:
# Figure sizes in pixels: standalone figures vs. each cell of a subplot grid.
full_img_size, subplot_img_size = 600, 300
# Channel indices (into the channels-first image) rendered as red, green, blue.
rgb_bands = [3, 2, 1]
# Use the dark theme for every Plotly figure in this notebook.
pio.templates.default = "plotly_dark"

In [None]:
# Root directory of the multitask dataset. Set the MULTITASK_DATA_PATH
# environment variable to run the notebook against another copy of the data;
# otherwise the original author's local path is used.
data_path = os.environ.get(
    "MULTITASK_DATA_PATH",
    "/Users/andreferreira/Documents_Offline/Datasets/WorkResearch/MultitaskNeurIPS2021",
)

### Define auxiliary functions

In [None]:
def process_image_for_plot(img, bands=None):
    """Turn a channels-first image into an array that Plotly can show as RGB.

    Selects the requested channels, moves the channel axis last
    ((C, H, W) -> (H, W, C), the layout ``px.imshow`` / ``go.Image`` expect),
    multiplies by 255 and clips the result to ``[0, 255]``.

    Args:
        img: Channels-first image, either a torch tensor (as yielded by the
            dataset) or a NumPy array / array-like. Values are presumably in
            roughly ``[0, 1]`` given the *255 scaling -- confirm against
            ``create_dataset``.
        bands: Channel indices to display, in (R, G, B) order. Defaults to the
            module-level ``rgb_bands`` when ``None``.

    Returns:
        ``np.ndarray`` of shape ``(H, W, len(bands))`` with values clipped to
        ``[0, 255]``.
    """
    if bands is None:
        bands = rgb_bands
    # Always index with a list: a tuple would be read by NumPy as a
    # multi-axis index instead of a selection of channels.
    channels = img[list(bands)]
    if hasattr(channels, "detach"):
        # Torch tensor: drop any autograd graph and move to CPU, otherwise
        # ``.numpy()`` raises for tensors that require grad or live on a GPU.
        channels = channels.detach().cpu().numpy()
    else:
        channels = np.asarray(channels)
    img_rgb = channels.transpose(1, 2, 0)
    img_rgb = np.clip(img_rgb * 255, 0, 255)
    return img_rgb

## Load the data

In [None]:
# Regression table that is handed to create_dataset as ``reg_data``;
# the bare name at the end displays it as the cell output.
reg_csv = os.path.join(data_path, "reg_co2_data.csv")
reg_data = pd.read_csv(reg_csv)
reg_data

In [None]:
# Training split with all 12 channels; segmentation labels are read from
# <data_path>/segmentation_labels/training/.
train_seg_dir = os.path.join(data_path, "segmentation_labels", "training/")
all_channels = [*range(12)]
data_train = create_dataset(
    datadir=data_path,
    seglabeldir=train_seg_dir,
    reg_data=reg_data,
    mult=1,
    train=True,
    channels=all_channels,
)

In [None]:
# Explicit iterator over the training dataset so samples can be pulled one at a time with next().
data_iter = iter(data_train)

In [None]:
# Number of samples to pull from the training iterator for inspection.
n_samples = 4
# next() advances the shared iterator, so re-running this cell (without
# recreating data_iter) yields the following samples, not the same ones.
# Raises StopIteration if fewer than n_samples remain.
samples = [next(data_iter) for _ in range(n_samples)]

## Visualise the data

### Just the images

In [None]:
# Which fields a dataset sample provides ("img" and "fpt" are used in the plots below).
samples[0].keys()

In [None]:
# RGB view of the first sample; the bare name displays the array as cell output.
img_rgb = process_image_for_plot(samples[0]["img"])
img_rgb

In [None]:
# Expected (H, W, 3) if "img" is channels-first (three rgb_bands, transposed channel-last) -- confirm.
img_rgb.shape

In [None]:
# Show the first sample's RGB image as a standalone full-size figure.
px.imshow(img_rgb, height=full_img_size, width=full_img_size)

In [None]:
# Lay the samples out on a two-column grid. Ceil-divide the row count so an
# odd number of samples gets a partially filled last row instead of an
# out-of-range row index; keep at least one row so an empty list still works.
n_cols = 2
n_rows = max(1, (len(samples) + n_cols - 1) // n_cols)
fig = make_subplots(rows=n_rows, cols=n_cols)
for idx, sample in enumerate(samples):
    row, col = divmod(idx, n_cols)
    fig.add_trace(
        go.Image(z=process_image_for_plot(sample["img"])),
        row=row + 1,
        col=col + 1,
    )
# Size the figure from the actual grid so every subplot cell is
# subplot_img_size pixels on a side.
fig.layout.height = subplot_img_size * n_rows
fig.layout.width = subplot_img_size * n_cols
fig

### Plume masks

In [None]:
# Shape of the first sample's plume mask; presumably (H, W) matching the image so the contour overlay lines up -- confirm.
samples[0]["fpt"].shape

In [None]:
# Outline the first sample's plume mask ("fpt") on top of its RGB image.
first_sample = samples[0]
mask_outline = go.Contour(
    z=first_sample["fpt"],
    showscale=False,
    contours=dict(start=0, end=1, size=2, coloring="lines"),
    line_width=4,
)
fig = px.imshow(
    process_image_for_plot(first_sample["img"]),
    height=full_img_size,
    width=full_img_size,
)
fig.add_trace(mask_outline)
fig

In [None]:
# Outline every sample's plume mask ("fpt") on top of its RGB image,
# one full-size figure per sample.
for sample in samples:
    fig = px.imshow(
        process_image_for_plot(sample["img"]),
        height=full_img_size,
        width=full_img_size,
    )
    fig.add_trace(
        go.Contour(
            z=sample["fpt"],
            showscale=False,
            # size=2 exceeds end - start, so presumably only the level at
            # start=0 is drawn -- confirm this is the intended mask boundary.
            contours=dict(start=0, end=1, size=2, coloring="lines"),
            line_width=4,
        )
    )
    # Inside a loop a bare `fig` would not render, so use the Jupyter
    # display() helper to show each figure.
    display(fig)