# **DEPENDENCIES**

In [71]:
# Local development: uncomment the two lines below to import DreamWalker from
# the parent directory instead of installing it.
# import sys
# sys.path.append('..')

# Colab: install DreamWalker straight from its GitHub repository.
# NOTE(review): the install is unpinned, so it follows whatever the repository's
# default branch currently contains; consider pinning a tag or commit hash for
# reproducible runs.
%pip install git+https://github.com/adityaprakash-work/DreamWalker.git

In [None]:
import os
import glob
import random

import numpy as np
import torch
from torch import optim
from torch.nn import functional as F
from torch.optim import lr_scheduler
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset, random_split

from google.colab import drive
import dreamwalker as dw
from dreamwalker.pytorch_generative import models, trainer
from dreamwalker.models.brain import ConformerEEGEncoder

drive.mount('/content/drive')

In [None]:
%load_ext tensorboard

# **DATASET**

In [None]:
# Run this cell to fetch the ImageNet-mini dataset from an online source (Kaggle).
# `dataset_dir` is passed as the second argument to the download helper and is
# the root that the train/val paths in the next cell point into.
dataset_url = "https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000/data"
dataset_dir = "/content/dataset"
dw.utils.datasets.download(dataset_url, dataset_dir)

In [None]:
# Image loaders used to train the VQ-VAE.
# Set `valid_dir = None` to split the validation set off `train_dir` instead of
# reading it from its own folder.
train_dir = "/content/dataset/imagenetmini-1000/imagenet-mini/train"
valid_dir = "/content/dataset/imagenetmini-1000/imagenet-mini/val"

if valid_dir is not None:
    # Separate folders: one image stream and one loader per split.
    train_dataset = dw.utils.datasets.ImageStream(train_dir, ext="JPEG")
    valid_dataset = dw.utils.datasets.ImageStream(valid_dir, ext="JPEG")
    train_loader = dw.utils.datasets.get_loaders(train_dataset, batch_size=16)
    valid_loader = dw.utils.datasets.get_loaders(valid_dataset, batch_size=16)
else:
    # Single folder: let the helper return both loaders
    # (valid_size=0.2 is presumably the held-out fraction — confirm).
    dataset = dw.utils.datasets.ImageStream(train_dir, ext="JPEG")
    train_loader, valid_loader = dw.utils.datasets.get_loaders(
        dataset, return_valid=True, valid_size=0.2
    )

In [None]:
# Google Drive folders holding the EEG recordings and the stimulus images.
eeg_data_path = "/content/drive/MyDrive/Brain2Image/raw_eeg2/"
img_data_path = "/content/drive/MyDrive/Brain2Image/sqr_images2/"

# A class id is the part of an image file name before the first underscore;
# np.unique returns the sorted, de-duplicated ids.
all_classes = np.unique(
    [name.split("_", 1)[0] for name in os.listdir(img_data_path)]
)

In [None]:
# Pick up to 40 *distinct* classes.
# np.random.choice samples with replacement by default, which would repeat
# classes; every EEG file of a repeated class would then be listed more than
# once and the duplicates could land on both sides of the train/validation
# split. `replace=False` prevents that, and the size is capped so the draw
# still works when fewer than 40 classes are available.
use_classes = list(
    np.random.choice(all_classes, min(40, len(all_classes)), replace=False)
)

# Every EEG recording whose file name contains one of the selected class ids.
use_eegs = [
    p for c in use_classes for p in glob.glob(f"{eeg_data_path}*{c}*")
]
random.shuffle(use_eegs)

# Matching image path: underscore-separated fields 2 and 3 of the EEG file name.
use_imgs = [
    os.path.join(img_data_path, "_".join(os.path.basename(p).split("_")[2:4]))
    for p in use_eegs
]

# Class label of each sample: the leading field of the image file name.
use_clsl = [
    os.path.basename(p).split("_")[0] for p in use_imgs
]

In [None]:
class DreamWalkerDataset(Dataset):
    """Paired EEG / image dataset backed by NumPy files on disk.

    One item is produced per EEG file path. The image file name is rebuilt
    from underscore-separated fields 2 and 3 of the EEG file name and looked
    up inside ``img_data_path``.

    Args:
        eeg_data_path: Directory of the EEG arrays (stored on the instance;
            items are loaded from the full paths in ``use_eegs``).
        img_data_path: Directory containing the image arrays.
        use_eegs: Sequence of EEG file paths, one per dataset item.
        transform: Optional callable applied to the loaded image array.
    """

    def __init__(self, eeg_data_path, img_data_path, use_eegs, transform=None):
        self.transform = transform
        self.use_eegs = use_eegs
        self.img_data_path = img_data_path
        self.eeg_data_path = eeg_data_path

    def __len__(self):
        """Return the number of EEG recordings, i.e. the number of items."""
        return len(self.use_eegs)

    def __getitem__(self, idx):
        """Load the ``idx``-th ``(eeg, image)`` pair.

        Returns:
            A tuple of the transposed EEG array with a new leading axis of
            size 1, and the image array after ``transform`` (when one is set).
        """
        eeg_file = self.use_eegs[idx]
        name_fields = os.path.basename(eeg_file).split("_")
        img_file = os.path.join(self.img_data_path, "_".join(name_fields[2:4]))

        # Transpose the stored array, then prepend a singleton axis.
        eeg = np.load(eeg_file).T[np.newaxis, ...]
        image = np.load(img_file)

        if self.transform:
            image = self.transform(image)
        return eeg, image

In [None]:
# Image preprocessing: array -> PIL image -> 128x128 resize -> tensor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

dataset = DreamWalkerDataset(
    eeg_data_path=eeg_data_path,
    img_data_path=img_data_path,
    use_eegs=use_eegs,
    transform=transform,
)

# Random 80/20 train/validation split; the remainder goes to validation so the
# two lengths always sum to len(dataset).
lt = int(0.8 * len(dataset))
lv = len(dataset) - lt
td, vd = random_split(dataset, [lt, lv])

# NOTE(review): the validation loader is shuffled as well, so "the first
# validation batch" differs between iterations.
dw_train_loader = DataLoader(td, batch_size=16, shuffle=True)
dw_valid_loader = DataLoader(vd, batch_size=16, shuffle=True)

# **VQVAE MODEL**

In [None]:
# VQ-VAE-2 with 3 input and 3 output channels.
vqvae_model = models.VectorQuantizedVAE2(
    in_channels=3,
    out_channels=3,
    hidden_channels=128,
    residual_channels=64,
    n_residual_blocks=2,
    n_embeddings=512,
    embedding_dim=64,
)

# Adam with a constant multiplicative decay applied at every scheduler step.
vqvae_optimizer = optim.Adam(vqvae_model.parameters(), lr=2e-4)
vqvae_scheduler = lr_scheduler.MultiplicativeLR(
    vqvae_optimizer,
    lr_lambda=lambda _step: 0.999977,
)


def vqvae_loss_fn(x, _, preds):
    """Compute the VQ-VAE training losses for one batch.

    Args:
        x: Input batch; also the reconstruction target.
        _: Unused second batch element (kept for the trainer's call signature).
        preds: Pair ``(reconstruction, vq_loss)`` produced by the model.

    Returns:
        Dict with ``"vq_loss"``, ``"reconstruction_loss"`` (MSE between the
        reconstruction and ``x``) and ``"loss"`` (reconstruction loss plus
        0.25 times the VQ loss).
    """
    reconstruction, vq_loss = preds
    recon_loss = F.mse_loss(reconstruction, x)

    return {
        "vq_loss": vq_loss,
        "reconstruction_loss": recon_loss,
        "loss": recon_loss + 0.25 * vq_loss,
    }

# Trainer for the VQ-VAE on the image loaders; logs go to /content/logs/vqvae0.
vqvae_model_trainer = trainer.Trainer(
    model=vqvae_model,
    optimizer=vqvae_optimizer,
    lr_scheduler=vqvae_scheduler,
    loss_fn=vqvae_loss_fn,
    train_loader=train_loader,
    eval_loader=valid_loader,
    log_dir="/content/logs/vqvae0",
    n_gpus=1,
)

# Load previously trained weights instead of retraining.
# NOTE(review): 7 is presumably the epoch/checkpoint index to restore — confirm
# against the Trainer implementation; the checkpoint must already exist in log_dir.
vqvae_model_trainer.restore_checkpoint(7)

# **ALIGNMENT MODEL**

In [None]:
class AlignmentModel(torch.nn.Module):
    """Aligns a brain (EEG) encoder with the latent spaces of a VQ-VAE-2.

    The brain module maps its input to a bottom and a top latent. These are
    compared against the VQ-VAE encoders' outputs for the paired image and
    then pushed through the VQ-VAE quantizers and decoders to reconstruct
    the image.

    Args:
        brain_module: Module returning ``(zb, zt)`` for an input batch.
        vqvae2_module: VQ-VAE-2 exposing ``_encoder_b``, ``_encoder_t``,
            ``_quantizer_b``, ``_quantizer_t``, ``_decoder_t``, ``_conv`` and
            ``_decoder_b``.
        vgrad: When falsy (default), gradients are disabled for every
            parameter of the VQ-VAE module.
    """

    def __init__(self, brain_module, vqvae2_module, vgrad=False):
        super().__init__()
        self.brain_module = brain_module
        self.vqvae_module = vqvae2_module
        if not vgrad:
            # Freeze the VQ-VAE parameters.
            self.vqvae_module.requires_grad_(False)

    def forward(self, x, y):
        """Reconstruct ``y`` from ``x`` and measure latent alignment.

        Args:
            x: Input batch for the brain module.
            y: Paired image batch fed to the VQ-VAE encoders.

        Returns:
            Tuple ``(xhat, vq_loss, algn_loss)``: the decoded image, the
            averaged quantizer losses plus the MSE between the decoded top
            latent and the image's bottom encoding, and the summed cosine
            alignment losses of both latent levels.
        """
        vqvae = self.vqvae_module
        zb, zt = self.brain_module(x)

        # Image-side targets for the two latent levels.
        image_b = vqvae._encoder_b(y)
        image_t = vqvae._encoder_t(image_b)

        # Cosine distance between brain latents and image latents.
        top_alignment = 1 - F.cosine_similarity(zt, image_t).mean()
        bottom_alignment = 1 - F.cosine_similarity(zb, image_b).mean()

        quantized_t, vq_loss_t = vqvae._quantizer_t(zt)
        quantized_b, vq_loss_b = vqvae._quantizer_b(zb)

        decoded_t = vqvae._decoder_t(quantized_t)
        decoder_input = torch.cat((vqvae._conv(decoded_t), quantized_b), dim=1)
        xhat = vqvae._decoder_b(decoder_input)

        quantizer_term = 0.5 * (vq_loss_b + vq_loss_t)
        vq_loss = quantizer_term + F.mse_loss(decoded_t, image_b)

        return xhat, vq_loss, top_alignment + bottom_alignment

# **TRAINING**

In [None]:
# EEG encoder paired with the VQ-VAE. With the default vgrad=False the VQ-VAE
# parameters have gradients disabled, and only the brain module's parameters
# are handed to the optimizer.
algn_model = AlignmentModel(
    brain_module=ConformerEEGEncoder(),
    vqvae2_module=vqvae_model,
)

algn_optimizer = optim.Adam(algn_model.brain_module.parameters(), lr=2e-4)
algn_scheduler = lr_scheduler.MultiplicativeLR(
    algn_optimizer,
    lr_lambda=lambda _step: 0.999977,
)


def algn_loss_fn(x, y, preds):
    """Compute the alignment-model training losses for one batch.

    Args:
        x: Model input batch (unused here; kept for the trainer's call
            signature).
        y: Target image batch the reconstruction is compared against.
        preds: Triple ``(xhat, vq_loss, algn_loss)`` returned by the model.

    Returns:
        Dict with ``"algn_loss"``, ``"reconstruction_loss"`` (MSE between
        ``xhat`` and ``y``) and ``"loss"`` (alignment loss plus 0.25 times the
        reconstruction loss).
    """
    xhat, _, algn_loss = preds
    # Compare the reconstructed image, not the whole `preds` tuple, with the
    # target; passing the tuple to mse_loss fails at runtime.
    recon_loss = F.mse_loss(xhat, y)
    loss = algn_loss + 0.25 * recon_loss

    return {
        "algn_loss": algn_loss,
        "reconstruction_loss": recon_loss,
        "loss": loss,
    }


# Trainer for the alignment model on the paired EEG/image loaders;
# logs go to /content/logs/algn0.
agn_model_trainer = trainer.Trainer(
    model=algn_model,
    optimizer=algn_optimizer,
    lr_scheduler=algn_scheduler,
    loss_fn=algn_loss_fn,
    train_loader=dw_train_loader,
    eval_loader=dw_valid_loader,
    log_dir="/content/logs/algn0",
    n_gpus=1,
)

In [None]:
def make_grid_ovsr(original, reconstructions):
    """Build one image showing originals next to their reconstructions.

    Each batch is tiled into its own normalized grid with
    ``ceil(sqrt(batch_size))`` images per row (batch size taken from
    ``original``); the two grids are then joined along the last dimension.

    Args:
        original: Batch of original images.
        reconstructions: Batch of reconstructed images.

    Returns:
        A single tensor holding the original grid followed by the
        reconstruction grid along the last axis.
    """
    per_row = int(np.ceil(np.sqrt(original.shape[0])))
    grids = [
        make_grid(batch, nrow=per_row, normalize=True)
        for batch in (original, reconstructions)
    ]
    return torch.cat(grids, dim=-1)

def recplt_monitor(model_trainer):
    """Log a grid of target images vs. reconstructions for one eval batch.

    Takes the first batch from the trainer's eval loader, runs the model on
    it in eval mode, and writes an "original | reconstruction" image to the
    trainer's summary writer under the tag "Reconstruction Fidelity" at the
    trainer's current step. The model is put back into train mode afterwards,
    even if the forward pass or the logging raises.

    Args:
        model_trainer: Trainer exposing ``model``, ``eval_loader``,
            ``device``, ``_summary_writer`` and ``_step``.
    """
    model = model_trainer.model
    model.eval()
    try:
        x, y = next(iter(model_trainer.eval_loader))
        x = x.to(model_trainer.device)
        # The alignment model's forward takes both the EEG batch and the
        # paired image batch, so the images must be passed as well.
        # No gradients are needed for monitoring.
        with torch.no_grad():
            x_recon, _, _ = model(x, y.to(model_trainer.device))
        x_recon = x_recon.cpu().detach()
        model_trainer._summary_writer.add_image(
            "Reconstruction Fidelity",
            make_grid_ovsr(y, x_recon),
            model_trainer._step,
        )
    finally:
        model.train()

# Training is driven by the Trainer, not the nn.Module: `algn_model` has no
# `interleaved_train_and_eval` attribute, so the call must go through the
# trainer that wraps it.
agn_model_trainer.interleaved_train_and_eval(7, arbitrary_monitors=[recplt_monitor])