In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from timebudget import timebudget
from itertools import chain
from rich import print as pprint

import torch
import torch.nn as nn

from sourcesep.sim import SimData
from sourcesep.utils.config import load_config
from sourcesep.models.baseunet import BaseUnet

# Apply the seaborn default theme with font elements scaled to 0.8x.
sns.set_theme(font_scale=0.8)
# IPython magic: ask the inline backend to render figures in 'retina' format.
%config InlineBackend.figure_format='retina'

In [None]:
# Resolve the dataset paths and build the simulator from the TOML config that
# sits under the 'root' path.
paths = load_config(dataset_key='all')
sim_cfg_file = paths['root'] / "sim_config.toml"
sim = SimData(T=1024, cfg_path=sim_cfg_file)

# Compose one sample and look at what it contains: available keys and the
# shape of the 'O' entry (the array later fed to the model).
dat = sim.compose()
print(dat.keys())
print(dat['O'].shape)

# List the entries configured in the 'indicator' and 'laser' sections.
for section in ('indicator', 'laser'):
    print(sim.cfg[section].keys())

In [None]:
# Already imported in the first cell; repeating the imports here is harmless.
import torch
import torch.nn as nn

# Data layout notes (axis sizes are the constants defined just below):
#   observation : (T, J, L) = (time, laser, lambda)
#   sources     : (T, I)    = (time, sources={indicators, hemodynamics, noise})
#                 NOTE(review): the original note gives I=8 -- TODO confirm.
#   a Conv3d model would expect (N, C=1, D=time, H=lasers, W=lambda)
T = 1024  # time samples
J = 5     # lasers
L = 300   # lambda bins

# Three independent mean-reduced L1 (absolute error) criteria; the `mse_`
# prefix is historical -- these are not squared-error losses.
mse_A, mse_H_ox, mse_H_dox = (nn.L1Loss(reduction='mean') for _ in range(3))

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Network hyper-parameters:
#   in_channels=1500 matches the feature axis built by the data-prep cells
#   (observation reshaped to (1, 1024, 1500), then axes 1 and 2 swapped);
#   out_channels=5 matches how the training loop reads the output
#   (channels 0-2 -> A, channel 3 -> H_ox, channel 4 -> H_dox).
unet_kwargs = {'in_channels': 1500, 'out_channels': 5}
model = BaseUnet(**unet_kwargs)

# Move the parameters to the selected device and report which one is in use.
model.to(device)
print(device)

In [None]:
# NOTE: `n_steps` is reassigned (to 5000) in the training-loop cell below, so
# the value set here is not the one the loop ends up using.
n_steps = 500

# Adam over all model parameters.
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Draw one simulated sample and convert it into model-ready tensors.
dat = sim.compose()

# input reshaping: insert a leading 'batch' axis, merge the laser and lambda
# axes into a single feature axis, then move time last -> (1, J*L, T).
# The target shape is passed positionally because the `newshape=` keyword of
# np.reshape is deprecated since NumPy 2.1. T, J and L are the notebook
# constants (1024, 5, 300), so the shape is the same (1, 1024, 1500) as before.
# (`input` shadows the builtin of the same name; the name is kept because the
# rest of the notebook uses it.)
input = np.reshape(dat['O'][np.newaxis, ...], (1, T, J * L))
input = np.swapaxes(input, 1, 2)
input = torch.as_tensor(input, dtype=torch.float32, device=device)

# targets, as float32 tensors on the same device as the input
A, H_ox, H_dox = (
    torch.as_tensor(dat[key], dtype=torch.float32, device=device)
    for key in ('A', 'H_ox', 'H_dox')
)

In [None]:
plot_freq = 500  # report the loss and plot every `plot_freq` steps
n_steps = 5000   # number of simulated samples to iterate over
edge = 256       # samples trimmed from each end of the time axis
                 # (avoid dealing with boundary issues for now)


def tonumpy(x):
    """Return tensor `x` as a NumPy array (moved to the CPU, detached from autograd)."""
    return x.cpu().detach().numpy()


def _to_device_tensor(array):
    """Convert `array` to a float32 tensor on the notebook's `device`."""
    return torch.as_tensor(array, dtype=torch.float32, device=device)


def _plot_reconstruction(A, Ar, H_ox, H_oxr, H_dox, H_doxr):
    """Plot targets (blue) against model reconstructions (red), then free the figures.

    The first figure shows the three columns of `A` against the three rows of
    `Ar`, both trimmed by `edge` samples at each end; the second figure shows
    `H_ox` / `H_dox` against `H_oxr` / `H_doxr` untrimmed.
    """
    fig_a, axes_a = plt.subplots(3, 1, figsize=(4, 6))
    A_np, Ar_np = tonumpy(A), tonumpy(Ar).T
    for i, ax in enumerate(axes_a):
        ax.plot(A_np[edge:-edge, i], '-b')
        ax.plot(Ar_np[edge:-edge, i], '-r', alpha=0.5)

    fig_h, axes_h = plt.subplots(2, 1, figsize=(4, 4))
    panels = zip(axes_h, ('ox', 'dox'), (H_ox, H_dox), (H_oxr, H_doxr))
    for ax, title, target, recon in panels:
        ax.plot(tonumpy(target), '-b')
        ax.plot(tonumpy(recon.T), '-r', alpha=0.5)
        ax.set(title=title)

    plt.show()
    # Close explicitly so repeated plotting steps never accumulate open figures.
    plt.close(fig_a)
    plt.close(fig_h)


for step in range(n_steps):
    # A fresh simulated sample is drawn at every step.
    dat = sim.compose()

    # input reshaping: add a 'batch' axis, merge laser and lambda axes, move
    # time last -> (1, J*L, T). The shape is passed positionally because the
    # `newshape=` keyword of np.reshape is deprecated since NumPy 2.1.
    input = np.reshape(dat['O'][np.newaxis, ...], (1, T, J * L))
    input = _to_device_tensor(np.swapaxes(input, 1, 2))

    # targets
    A = _to_device_tensor(dat['A'])
    H_ox = _to_device_tensor(dat['H_ox'])
    H_dox = _to_device_tensor(dat['H_dox'])

    # model forward pass (`input` was created on `device` already)
    output = model(input)

    # split the output channels: 0-2 -> A, 3 -> H_ox, 4 -> H_dox
    Ar = torch.squeeze(output[0, 0:3, ...])
    H_oxr = torch.squeeze(output[0, 3, ...])
    H_doxr = torch.squeeze(output[0, 4, ...])

    # loss: L1 on each of the three A columns away from the boundaries
    # (weight 10 each); the hemodynamic terms are currently weighted by 0.
    Ar_core = Ar.T[edge:-edge]
    A_core = A[edge:-edge]
    loss = sum(10*mse_A(Ar_core[:, k], A_core[:, k]) for k in range(3)) \
        + 0*mse_H_ox(H_oxr, H_ox) \
        + 0*mse_H_dox(H_doxr, H_dox)

    if (step+1) % plot_freq == 0:
        # NOTE(review): reporting steps only evaluate and plot; no optimizer
        # update is made on them -- confirm this is intended.
        print(f'Step: {step} -- Loss: {tonumpy(loss):0.4f}')
        _plot_reconstruction(A, Ar, H_ox, H_oxr, H_dox, H_doxr)
    else:
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
