In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from datetime import datetime
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt

from lurz2020.models.models import se2d_fullgaussian2d
from csng.utils import get_corr, plot_comparison, standardize, normalize, get_mean_and_std, count_parameters
from csng.losses import SSIMLoss

from data_orig import prepare_spiking_data_loaders
# from data import prepare_data_loaders
from data import prepare_data_loaders, SyntheticDataset, BatchPatchesDataLoader

# Root directory of the cat V1 spiking-model data; DATA_PATH must be set in the environment.
DATA_PATH = os.path.join(
    os.environ["DATA_PATH"],
    "cat_V1_spiking_model",
)

# Patch torch tensor reprs via lovely_tensors for more readable printing.
lt.monkey_patch()

print(f"{DATA_PATH=}")

In [None]:
# Global run configuration: data locations / loader options, compute device, RNG seed.
config = dict(
    data=dict(
        train_path=os.path.join(DATA_PATH, "datasets", "train"),
        val_path=os.path.join(DATA_PATH, "datasets", "val"),
        test_path=os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
        image_size=[50, 50],
        crop=False,
        batch_size=8,
    ),
    # fall back to CPU when no GPU is visible
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=0,
)
print(f"... Running on {config['device']} ...")

In [None]:
### seed every RNG the notebook touches (Python, PyTorch, NumPy) from config["seed"]
random.seed(config["seed"])
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

## Data

In [None]:
### build the spiking-model loaders from config["data"] and keep a handle on the test split
data_loaders = prepare_spiking_data_loaders(
    **config["data"],
)
oracle_dataloader = data_loaders["test"]  # held-out test split

In [None]:
### show one training sample: stimulus (left) and its recorded response (right)
data_sample = next(iter(data_loaders["train"]["spiking"]))
stim, resp = data_sample[0].float(), data_sample[1]

fig, (ax_stim, ax_resp) = plt.subplots(1, 2, figsize=(10, 6))

# squeeze away singleton (e.g. channel) dims so imshow gets a 2-D image
ax_stim.imshow(stim[0].cpu().squeeze(), cmap="gray")
ax_stim.set_title("Stimulus")

# reshape the flat response vector into a 100x100 image for display
# NOTE(review): assumes exactly 10_000 response units per sample — TODO confirm
ax_resp.imshow(resp[0].cpu().reshape(100, 100), cmap="gray")
ax_resp.set_title("Response")

plt.show()

In [None]:
### central 20x20 window (indices 15..34 of the 50x50 stimuli) used for reconstruction plots
crop_win = (slice(15, 35), slice(15, 35))


def crop_stim(x):
    """Return `x` restricted to `crop_win` along its last two dimensions."""
    rows, cols = crop_win
    return x[..., rows, cols]

## Synthetic data

In [None]:
# NOTE(review): commented-out scratch copy of the spiking data-loader and encoder
# settings, kept for reference only. The live values are config["data"] (batch_size
# is 8 there vs. 32 here) and config["encoder_inversion"]["model"] further below.
# from data_orig import prepare_spiking_data_loaders
# from lurz2020.models.models import se2d_fullgaussian2d

# spiking_data_loaders_config = {
#     "train_path": os.path.join(DATA_PATH, "datasets", "train"),
#     "val_path": os.path.join(DATA_PATH, "datasets", "val"),
#     "test_path": os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
#     "image_size": [50, 50],
#     "crop": False,
#     "batch_size": 32,
# }
# encoder_config = {
#     "init_mu_range": 0.55,
#     "init_sigma": 0.4,
#     "input_kern": 19,
#     "hidden_kern": 17,
#     "hidden_channels": 32,
#     "gamma_input": 1.0,
#     "gamma_readout": 2.439,
#     "grid_mean_predictor": None,
#     "layers": 5
# }

In [None]:
### Build the pretrained encoder HERE: the SyntheticDataset in the next cell is constructed
### with `encoder=encoder`, so on a fresh kernel (Restart & Run All) it must already exist.
### NOTE(review): these hyperparameters mirror config["encoder_inversion"]["model"] in the
### Encoder section below — keep the two in sync.
encoder_config = {
    "init_mu_range": 0.55,
    "init_sigma": 0.4,
    "input_kern": 19,
    "hidden_kern": 17,
    "hidden_channels": 32,
    "gamma_input": 1.0,
    "gamma_readout": 2.439,
    "grid_mean_predictor": None,
    "layers": 5,
}
encoder = se2d_fullgaussian2d(
    **encoder_config,
    dataloaders=data_loaders,
    seed=2,
)

### load pretrained weights (full state dict, strict)
pretrained_core = torch.load(
    os.path.join(DATA_PATH, "models", "spiking_scratch_tunecore_68Y_model.pth"),
    # follow config["device"] so loading also works on CPU-only machines
    map_location=config["device"],
)
encoder.load_state_dict(pretrained_core, strict=True)
encoder.to(config["device"])
_ = encoder.eval()

In [None]:
### Synthetic dataset: image patches from the sensorium22 GrayImageNet images, paired with
### responses from the pretrained `encoder`.
syn_data_imgs_path = os.path.join(
    os.environ["DATA_PATH"],
    "sensorium22",
    "static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6",
    "data",
    "images",
)

# Per-unit response statistics of the synthetic dataset, as float tensors on the device.
resp_mean, resp_std = (
    torch.from_numpy(np.load(os.path.join(DATA_PATH, stat_file))).float().to(config["device"])
    for stat_file in (
        "responses_mean_from_syn_dataset.npy",
        "responses_std_from_syn_dataset.npy",
    )
)

# Normalization is currently disabled. Previously tried alternatives:
#   stim: transforms.Normalize(mean=114.457, std=51.356)
#   resp: transforms.Lambda(lambda x: (x - resp_mean) / resp_std)
syn_dataset = SyntheticDataset(
    data_dir=syn_data_imgs_path,
    patch_size=config["data"]["image_size"][0],
    overlap=15,
    encoder=encoder,
    expand_stim_for_encoder=False,
    stim_transform=None,
    resp_transform=None,
    device=config["device"],
)

# two images per batch; BatchPatchesDataLoader wraps the loader to yield their patches
_dataloader = DataLoader(syn_dataset, batch_size=2, shuffle=True)
syn_dataloader = BatchPatchesDataLoader(_dataloader)

In [None]:
### show one synthetic sample: stimulus patch (left) and its response (right)
syn_stim, syn_resp = next(iter(syn_dataloader))

fig, (ax_stim, ax_resp) = plt.subplots(1, 2, figsize=(10, 6))

# squeeze away singleton (e.g. channel) dims so imshow gets a 2-D image
ax_stim.imshow(syn_stim[0].cpu().squeeze(), cmap="gray")
ax_stim.set_title("Synthetic stimulus")

# reshape the flat response vector into a 100x100 image for display
# NOTE(review): assumes exactly 10_000 response units per sample — TODO confirm
ax_resp.imshow(syn_resp[0].cpu().reshape(100, 100), cmap="gray")
ax_resp.set_title("Synthetic response")

plt.show()

## Encoder

In [None]:
### Encoder architecture + settings for inverting it by gradient descent on the stimulus.
config["encoder_inversion"] = dict(
    model=dict(
        init_mu_range=0.55,
        init_sigma=0.4,
        input_kern=19,
        hidden_kern=17,
        hidden_channels=32,
        gamma_input=1.0,
        gamma_readout=2.439,
        grid_mean_predictor=None,
        layers=5,
    ),
    n_inits=15,     # number of random restarts
    n_steps=8000,   # optimizer steps per restart
    opter_cls=torch.optim.Adam,
    opter_kwargs=dict(lr=0.3),
    loss_fn=nn.MSELoss(),  # alternative tried: nn.L1Loss()
)

In [None]:
### encoder used as the forward model for inversion (built with the 50x50-image data loaders)
encoder = se2d_fullgaussian2d(
    **config["encoder_inversion"]["model"],
    dataloaders=data_loaders,
    seed=2,
)

### load pretrained weights (full state dict, strict)
pretrained_core = torch.load(
    os.path.join(DATA_PATH, "models", "spiking_scratch_tunecore_68Y_model.pth"),
    # Follow config["device"]. Hard-coding "cuda" fails on CPU-only machines,
    # even though config["device"] falls back to "cpu" there.
    map_location=config["device"],
)
encoder.load_state_dict(pretrained_core, strict=True)
encoder.to(config["device"])
_ = encoder.eval()

## Run inversion

In [None]:
### inversion targets: the synthetic batch as float tensors on the compute device
target_stim, target_resp = (
    tensor.float().to(config["device"]) for tensor in (syn_stim, syn_resp)
)

In [None]:
### Optional: compare the current reconstruction with the target on the central 30x30 window.
### `stim_pred` only exists after the inversion cell below has run, so skip on a fresh kernel
### instead of raising NameError during Restart & Run All.
center_win = (slice(10, 40), slice(10, 40))
if "stim_pred" in globals():
    # crop target AND prediction identically so the two panels are comparable
    plot_comparison(
        target=target_stim[:8, :, center_win[0], center_win[1]].cpu(),
        pred=stim_pred[:8, :, center_win[0], center_win[1]].detach().cpu(),
    )
else:
    print("... `stim_pred` not defined yet: run the inversion cell below first ...")

In [None]:
### Invert the encoder: for each random init, optimize the stimulus so that
### encoder(stimulus) matches `target_resp`; keep the final stimulus of every init.

# The encoder is only a fixed forward model here. Freeze its parameters so backward()
# computes gradients w.r.t. the stimulus only. Otherwise .grad accumulates on every
# encoder parameter, and opter.zero_grad() never clears it because the optimizer only
# owns `stim_pred`.
encoder.requires_grad_(False)

init_scale = 100.  # initial stimulus ~ U(0, init_scale)
log_every = 500    # steps between loss printouts / reconstruction plots

stim_preds = []

for init_i in range(config["encoder_inversion"]["n_inits"]):
    print(f"Init {init_i}")

    ### init decoded img with uniform noise (alternative tried: zeros)
    stim_pred = (
        torch.rand_like(target_stim, device=config["device"]) * init_scale
    ).requires_grad_(True)
    opter = config["encoder_inversion"]["opter_cls"](
        [stim_pred], **config["encoder_inversion"]["opter_kwargs"],
    )

    loss_history = []
    for step_i in range(config["encoder_inversion"]["n_steps"]):
        opter.zero_grad()
        resp_pred = encoder(stim_pred).float()
        loss = config["encoder_inversion"]["loss_fn"](resp_pred, target_resp)
        loss.backward()
        opter.step()

        # loss of the stimulus *before* this step's update
        loss_history.append(loss.item())

        if step_i % log_every == 0:
            print(f"Step {step_i}: {loss_history[-1]:.3f}")
            ### plot reconstruction samples (central crop, then full stimulus)
            plot_comparison(
                target=crop_stim(target_stim[:8].cpu()),
                pred=crop_stim(stim_pred[:8].detach().cpu()),
            )
            plot_comparison(
                target=target_stim[:8].cpu(),
                pred=stim_pred[:8].detach().cpu(),
            )

    ### plot loss history of this init
    fig, ax = plt.subplots()
    ax.plot(loss_history)
    ax.set(title=f"Inversion loss (init {init_i})", xlabel="step", ylabel="loss")
    plt.show()

    stim_preds.append(stim_pred.detach().cpu())

In [None]:
### average the reconstructions of all random inits, then compare central crops to the targets
stim_pred = torch.mean(torch.stack(stim_preds, dim=0), dim=0)
plot_comparison(
    target=crop_stim(target_stim[:8]).cpu(),
    pred=crop_stim(stim_pred[:8]).detach().cpu(),
)

In [None]:
### plot reconstruction samples: central crop of the averaged reconstruction vs. its target
# Compare against `target_stim`: the synthetic stimuli paired with `target_resp`, which the
# inversion fit. Do not compare against `stim`, an unrelated batch from the real spiking
# training loader.
plot_comparison(
    target=crop_stim(target_stim)[:8].cpu(),
    pred=crop_stim(stim_pred)[:8].detach().cpu(),
)