In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import time

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset, random_split
import torchdiffeq
import torchsde
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F
from torch.autograd import grad
from geomloss import SamplesLoss
from pykeops.torch import LazyTensor

from torchcfm.conditional_flow_matching import *
from torchcfm.models.unet import UNetModel

# Weights & Biases credentials must come from the environment (`export WANDB_API_KEY=...`
# or a one-time `wandb login`); never hardcode an API key in a notebook.
# NOTE(review): a key was committed in an earlier version of this cell — revoke it.
os.environ.setdefault("WANDB_MODE", "online")  # lets `WANDB_MODE=offline` override

In [None]:
import wandb
import random

# Start a W&B run for the Exact-OT CFM experiment below. The config lists the
# hyperparameters the following cells actually use, so the dashboard describes
# the real run.
wandb.init(
    project="CFM_EEG2ECoG",
    config={
        "learning_rate": 2e-4,
        "architecture": "EOT_CFM",
        "dataset": "EEG2ECoG",
        "epochs": 250,
        "sigma": 0.1,
        "batch_size": 10,
        "unet_num_channels": 32,
        "unet_num_res_blocks": 1,
        "image_size": 128,
        "image_channels": 4,
    },
)

# Dataset preprocessing

In [None]:
from torchvision import transforms

class TopoMapDataset(Dataset):
    """Paired EEG/ECoG topographic-map dataset.

    For each subject id, loads ``response_gen_topo_eeg_{id}.png`` and
    ``response_gen_topo_ecog_{id}.png`` from ``data_directory``. Ids missing
    either file are skipped. All images are loaded into memory at construction.

    Args:
        data_directory: Directory containing the PNG files.
        ids: Iterable of subject ids to look up.
        transform: Optional callable applied to each RGBA PIL image
            (e.g. a torchvision ``Compose``).
    """

    def __init__(self, data_directory, ids, transform=None):
        self.transform = transform
        self.data = []  # list of (subject_id, eeg_image, ecog_image)

        for subject_id in ids:
            eeg_file_path = os.path.join(data_directory, f"response_gen_topo_eeg_{subject_id}.png")
            ecog_file_path = os.path.join(data_directory, f"response_gen_topo_ecog_{subject_id}.png")

            if not (os.path.exists(eeg_file_path) and os.path.exists(ecog_file_path)):
                continue

            # Convert to RGBA *before* the transform: once ToTensor has run the
            # image is a tensor and can no longer be converted.
            eeg_image = self._load_image(eeg_file_path)
            ecog_image = self._load_image(ecog_file_path)

            if self.transform:
                eeg_image = self.transform(eeg_image)
                ecog_image = self.transform(ecog_image)

            self.data.append((subject_id, eeg_image, ecog_image))

    def _load_image(self, path):
        """Read ``path`` fully, convert it to RGBA and close the file handle."""
        # Image.open is lazy and keeps the file open; copy inside the context
        # manager so the handle is released even when no transform loads it.
        with Image.open(path) as image:
            return self.convert_to_four_channels(image).copy()

    def __len__(self):
        return len(self.data)

    def convert_to_four_channels(self, image):
        """Return a four-channel (RGBA) version of a PIL image.

        Non-PIL inputs (e.g. tensors produced by ``transform``) are returned unchanged.
        """
        # torch.Tensor also has a ``mode`` attribute (a method), so require a
        # PIL-style string mode rather than mere attribute presence.
        mode = getattr(image, "mode", None)
        if isinstance(mode, str) and mode != "RGBA":
            image = image.convert("RGBA")
        return image

    def __getitem__(self, idx):
        subject_id, eeg_image, ecog_image = self.data[idx]
        return {
            "subject_id": subject_id,
            "eeg_image": self.convert_to_four_channels(eeg_image),
            "ecog_image": self.convert_to_four_channels(ecog_image),
        }

In [None]:
# Location of the paired topographic maps (override with TOPO_MAPS_DIR). It is
# defined here, before first use, so the notebook survives Restart & Run All.
data_directory = os.environ.get("TOPO_MAPS_DIR", "/beegfs/home/ruslan.kalimullin/MScThesis/topo-maps")

EEG_PREFIX = "response_gen_topo_eeg_"
ECOG_PREFIX = "response_gen_topo_ecog_"


def _subject_ids(file_names, prefix):
    """Extract ``id`` from file names of the form ``{prefix}{id}.png``."""
    return {os.path.splitext(name)[0][len(prefix):] for name in file_names}


all_files = os.listdir(data_directory)
file_names_eeg = [f for f in all_files if f.startswith(EEG_PREFIX) and f.endswith(".png")]
file_names_ecog = [f for f in all_files if f.startswith(ECOG_PREFIX) and f.endswith(".png")]
# Strip the prefix/extension: splitting on "_" and taking [0] gave "response" for every file.
file_ids_eeg = sorted(_subject_ids(file_names_eeg, EEG_PREFIX))
file_ids_ecog = sorted(_subject_ids(file_names_ecog, ECOG_PREFIX))

# Subjects that have both an EEG and an ECoG map.
file_ids = set(file_ids_eeg) & set(file_ids_ecog)

In [None]:
# Resize every topographic map to the UNet input resolution, then convert it to a
# float tensor with ToTensor.
IMAGE_SIZE = (128, 128)
transform = transforms.Compose(
    [transforms.Resize(IMAGE_SIZE), transforms.ToTensor()]
)

In [None]:
# Keep only subjects with both modalities, in a deterministic order, and seed the
# split so the train/test partition (and the samples shown later) is reproducible.
SPLIT_SEED = 42
eeg_ids = {f.split("_")[-1].split(".")[0] for f in file_names_eeg}
ecog_ids = {f.split("_")[-1].split(".")[0] for f in file_names_ecog}
unique_ids = sorted(eeg_ids & ecog_ids)
dataset = TopoMapDataset(data_directory=data_directory, ids=unique_ids, transform=transform)
assert len(dataset) > 0, f"no paired EEG/ECoG maps found in {data_directory}"

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

batch_size = 10
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
from torchvision.transforms import ToPILImage
# Sanity check: render one (EEG, ECoG) training pair side by side.
index_to_display = 10
data_item = train_dataset[index_to_display]
eeg_image, ecog_image = data_item['eeg_image'], data_item['ecog_image']

to_pil = ToPILImage()
eeg_pil_image, ecog_pil_image = to_pil(eeg_image), to_pil(ecog_image)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, pil_image, title in zip(axes, (eeg_pil_image, ecog_pil_image), ('EEG Image', 'ECoG Image')):
    ax.imshow(pil_image)
    ax.set_title(title)

plt.show()


# Exact Optimal Transport Conditional Flow Matching

In [None]:
# Hyperparameters for the Exact-OT CFM run.
sigma = 0.1
n_epochs = 250
learning_rate = 2e-4  # matches "learning_rate" in the wandb.init config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible weight initialisation

# Velocity field v_theta(t, x) over 4-channel 128x128 maps.
model = UNetModel(
    dim=(4, 128, 128),
    num_channels=32,
    num_res_blocks=1,
    num_classes=None,
    class_cond=False,
).to(device)

# Pass the logged learning rate explicitly; Adam's default would silently differ.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
node = NeuralODE(model, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4).to(device)

start = time.time()

# Physiologically Inspired Loss

In [None]:
def laplacian_loss(pred_image, true_image):
    """MSE between the discrete Laplacians of two image batches.

    The 5-point Laplacian stencil is applied to every channel independently
    (depthwise conv, zero padding keeps H x W) and the responses are compared.

    Args:
        pred_image: Tensor of shape (B, C, H, W).
        true_image: Tensor with the same shape as ``pred_image``.

    Returns:
        Scalar tensor: mean squared difference of the two Laplacian maps.
    """
    channels = pred_image.shape[1]
    # Build the kernel in the input's dtype/device; a fixed float32 kernel makes
    # conv2d fail for half- or double-precision inputs.
    laplacian_filter = torch.tensor(
        [[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]],
        dtype=pred_image.dtype,
        device=pred_image.device,
    ).view(1, 1, 3, 3).repeat(channels, 1, 1, 1)  # one stencil per channel (groups=C)
    pred_laplace = F.conv2d(pred_image, laplacian_filter, padding=1, groups=channels)
    true_laplace = F.conv2d(true_image, laplacian_filter, padding=1, groups=channels)
    return F.mse_loss(pred_laplace, true_laplace)

def wasserstein_loss(eeg_image, ecog_image):
    """Sinkhorn divergence between the per-pixel channel vectors of two image batches.

    Each image becomes a cloud of H*W points in C-dimensional channel space.
    Despite the parameter names, callers pass (prediction, target).

    Args:
        eeg_image: Tensor of shape (B, C, H, W).
        ecog_image: Tensor of shape (B, C, H, W).

    Returns:
        Output of geomloss ``SamplesLoss`` on (B, H*W, C) inputs — presumably one
        value per batch item (hence the ``.mean()`` in the training loops); verify
        against the geomloss docs.
    """
    samples_loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.05)
    B, C, H, W = eeg_image.size()  # Batch, Channels, Height, Width
    # (B, C, H, W) -> (B, H*W, C): move channels last *before* flattening so each
    # point is one pixel's channel vector. A plain .view(B, -1, C) on NCHW memory
    # mixes values from different channels/pixels into one point.
    eeg_points = eeg_image.permute(0, 2, 3, 1).reshape(B, H * W, C)
    ecog_points = ecog_image.permute(0, 2, 3, 1).reshape(B, H * W, C)
    return samples_loss(eeg_points, ecog_points)

def hybrid_loss(model_output, true_image, eeg_image, alpha=0.5, beta=0.4, gamma=0.1):
    """Weighted mix of Sinkhorn, MSE and Laplacian-MSE losses.

    total = alpha * wasserstein_loss + beta * mse + gamma * laplacian_loss

    Args:
        model_output: Prediction tensor (B, C, H, W).
        true_image: Target tensor of the same shape.
        eeg_image: Accepted for interface compatibility; not used.
        alpha: Weight of the Sinkhorn (optimal-transport) term.
        beta: Weight of the pixel-wise MSE term.
        gamma: Weight of the Laplacian (second-derivative) MSE term.

    Returns:
        Tensor combining the three terms; its shape follows ``wasserstein_loss``.
    """
    mse_term = F.mse_loss(model_output, true_image)
    ot_term = wasserstein_loss(model_output, true_image)
    curvature_term = laplacian_loss(model_output, true_image)
    return alpha * ot_term + beta * mse_term + gamma * curvature_term

In [None]:
# Fixed held-out pair used to visualise progress during EOT training.
to_pil = ToPILImage()
index_to_display = min(100, len(test_dataset) - 1)  # stay in range for small test sets
data_item = test_dataset[index_to_display]

# Dataset items are already transformed tensors; use them directly instead of a
# tensor -> PIL (uint8) -> tensor round trip that re-quantises the pixels.
eeg_image_tensor = data_item["eeg_image"].unsqueeze(0).to(device)
ecog_image_tensor = data_item["ecog_image"].unsqueeze(0).to(device)

In [None]:
# Train the Exact-OT CFM velocity field to carry EEG maps (t=0) to ECoG maps (t=1).
# Logs the loss to W&B, shows a sample every `sample_every` epochs and saves a
# checkpoint after every epoch.
checkpoint_dir = "/beegfs/home/ruslan.kalimullin/GenAI_course/Project/Weights"
os.makedirs(checkpoint_dir, exist_ok=True)
sample_every = 5  # epochs between qualitative samples

for epoch in tqdm(range(n_epochs)):
    model.train()
    for i, data in enumerate(train_data_loader):
        optimizer.zero_grad()
        eeg_image = data["eeg_image"].to(device)
        ecog_image = data["ecog_image"].to(device)
        # The source sample is the EEG map itself. Sampling integrates the ODE from an
        # EEG map, so training must use the same x0 (Gaussian noise never shows EEG).
        x0 = eeg_image

        t, xt, ut = FM.sample_location_and_conditional_flow(x0, ecog_image)
        vt = model(t, xt)

        # Flow matching regresses the predicted velocity vt onto the conditional
        # target velocity ut, not onto the ECoG image.
        loss = hybrid_loss(vt, ut, eeg_image, alpha=0.5, beta=0.4, gamma=0.1)
        loss = loss.mean()  # hybrid_loss may return a per-sample vector
        wandb.log({"EOT_CFM_loss": loss.item()})  # log a float, not a graph-attached tensor
        loss.backward()
        optimizer.step()

        print(f"epoch: {epoch}, steps: {i}, loss: {loss.item():.4}", end="\r")

    # Sample once per `sample_every` epochs, outside the batch loop, so it does not
    # run an ODE solve and open a figure for every batch.
    if (epoch + 1) % sample_every == 0:
        model.eval()
        with torch.no_grad():
            traj = torchdiffeq.odeint(
                lambda t, x: model.forward(t, x),
                eeg_image_tensor,
                torch.linspace(0, 1, 2, device=device),
                atol=1e-4,
                rtol=1e-4,
                method="dopri5",
            )
        eeg_image_display = ToPILImage()(eeg_image_tensor.squeeze(0).cpu())
        ecog_image_display = ToPILImage()(ecog_image_tensor.squeeze(0).cpu())
        # Data comes from ToTensor, so clip to [0, 1]. ToPILImage scales by 255 and
        # negative values would wrap. make_grid ignores value_range unless normalize=True.
        generated_ecog = traj[-1, :1].view([-1, 4, 128, 128]).clip(0, 1)
        grid = make_grid(generated_ecog, padding=0, nrow=1)
        generated_ecog_display = ToPILImage()(grid.cpu())

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (eeg_image_display, "Input EEG"),
            (ecog_image_display, "Target ECoG"),
            (generated_ecog_display, "Generated ECoG from EEG by EOTCFM"),
        )
        for ax, (image, title) in zip(axs, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')
        plt.show()
        plt.close(fig)  # release the figure so samples do not accumulate in memory

    # Checkpoint the EOT model (`model`; `model_SB` is not defined yet in a fresh run).
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, "CFM_EOT_epoch=%d.pth" % epoch))

wandb.finish()

In [None]:
# Generate an ECoG map from a held-out EEG map with the trained Exact-OT CFM model.
USE_TORCH_DIFFEQ = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_directory = "/beegfs/home/ruslan.kalimullin/MScThesis/topo-maps"

to_pil = ToPILImage()
index_to_display = min(100, len(test_dataset) - 1)  # stay in range for small test sets
data_item = test_dataset[index_to_display]
# Dataset items are already transformed tensors; use them directly instead of a
# tensor -> PIL (uint8) -> tensor round trip.
eeg_image_tensor = data_item["eeg_image"].unsqueeze(0).to(device)
ecog_image_tensor = data_item["ecog_image"].unsqueeze(0).to(device)

# Integrate the learned velocity field from t=0 (EEG) to t=1 (generated ECoG).
t_span = torch.linspace(0, 1, 2, device=device)
model.eval()
with torch.no_grad():
    if USE_TORCH_DIFFEQ:
        traj = torchdiffeq.odeint(
            lambda t, x: model.forward(t, x),
            eeg_image_tensor,
            t_span,
            atol=1e-4,
            rtol=1e-4,
            method="dopri5",
        )
    else:
        # Build the torchdyn solver for *this* model. The global `node` is
        # reassigned to model_SB by the Schrödinger-bridge setup cell.
        eot_node = NeuralODE(model, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4).to(device)
        traj = eot_node.trajectory(eeg_image_tensor, t_span=t_span)

eeg_image_display = to_pil(eeg_image_tensor.squeeze(0).cpu())
ecog_image_display = to_pil(ecog_image_tensor.squeeze(0).cpu())

# Data comes from ToTensor, so clip to [0, 1] before ToPILImage (which scales by 255).
generated_ecog = traj[-1, :1].view([-1, 4, 128, 128]).clip(0, 1)
grid = make_grid(generated_ecog, padding=0, nrow=1)
generated_ecog_display = to_pil(grid.cpu())

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for ax, image, title in zip(
    axs,
    (eeg_image_display, ecog_image_display, generated_ecog_display),
    ("Input EEG", "Target ECoG", "Generated ECoG from EEG"),
):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
plt.show()

# Schrödinger Bridge Conditional Flow Matching

In [None]:
# Hyperparameters for the Schrödinger-bridge CFM run.
sigma = 0.1
n_epochs = 250
learning_rate = 2e-4  # matches "learning_rate" in the wandb.init config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible weight initialisation

# Velocity field v_theta(t, x) over 4-channel 128x128 maps.
model_SB = UNetModel(
    dim=(4, 128, 128),
    num_channels=32,
    num_res_blocks=1,
    num_classes=None,
    class_cond=False,
).to(device)

# Pass the logged learning rate explicitly; Adam's default would silently differ.
optimizer = torch.optim.Adam(model_SB.parameters(), lr=learning_rate)
FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)
node = NeuralODE(model_SB, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4).to(device)

start = time.time()

In [None]:
# Fixed held-out pair used to visualise progress during SB training.
to_pil = ToPILImage()
index_to_display = min(100, len(test_dataset) - 1)  # stay in range for small test sets
data_item = test_dataset[index_to_display]

# Dataset items are already transformed tensors; use them directly instead of a
# tensor -> PIL (uint8) -> tensor round trip that re-quantises the pixels.
eeg_image_tensor = data_item["eeg_image"].unsqueeze(0).to(device)
ecog_image_tensor = data_item["ecog_image"].unsqueeze(0).to(device)

In [None]:
# Train the Schrödinger-bridge CFM velocity field to carry EEG maps (t=0) to ECoG
# maps (t=1). Logs the loss to W&B, shows a sample every `sample_every` epochs and
# saves a checkpoint after every epoch.
checkpoint_dir = "/beegfs/home/ruslan.kalimullin/GenAI_course/Project/Weights"
os.makedirs(checkpoint_dir, exist_ok=True)
sample_every = 5  # epochs between qualitative samples

# The EOT training cell ends with wandb.finish(); open a new run so SB losses are
# not logged to a closed run.
if wandb.run is None:
    wandb.init(
        project="CFM_EEG2ECoG",
        config={
            "learning_rate": optimizer.param_groups[0]["lr"],
            "architecture": "SB_CFM",
            "dataset": "EEG2ECoG",
            "epochs": n_epochs,
            "sigma": sigma,
            "batch_size": batch_size,
        },
    )

for epoch in tqdm(range(n_epochs)):
    model_SB.train()
    for i, data in enumerate(train_data_loader):
        optimizer.zero_grad()
        eeg_image = data["eeg_image"].to(device)
        ecog_image = data["ecog_image"].to(device)

        # The source sample is the EEG map itself, matching inference, which
        # integrates the ODE starting from an EEG map.
        x0 = eeg_image

        # Sample a point on the bridge and its conditional target velocity.
        t, xt, ut = FM.sample_location_and_conditional_flow(x0, ecog_image)
        vt = model_SB(t, xt)

        # Flow matching regresses the predicted velocity vt onto the target velocity ut.
        loss = hybrid_loss(vt, ut, eeg_image, alpha=0.5, beta=0.4, gamma=0.1)
        loss = loss.mean()  # hybrid_loss may return a per-sample vector
        wandb.log({"SB_CFM_loss": loss.item()})  # log a float, not a graph-attached tensor
        loss.backward()
        optimizer.step()

        print(f"epoch: {epoch}, steps: {i}, loss: {loss.item():.4}", end="\r")

    # Sample once per `sample_every` epochs, outside the batch loop.
    if (epoch + 1) % sample_every == 0:
        model_SB.eval()
        with torch.no_grad():
            traj = torchdiffeq.odeint(
                lambda t, x: model_SB.forward(t, x),
                eeg_image_tensor,
                torch.linspace(0, 1, 2, device=device),
                atol=1e-4,
                rtol=1e-4,
                method="dopri5",
            )
        eeg_image_display = ToPILImage()(eeg_image_tensor.squeeze(0).cpu())
        ecog_image_display = ToPILImage()(ecog_image_tensor.squeeze(0).cpu())
        # Data comes from ToTensor, so clip to [0, 1]. ToPILImage scales by 255 and
        # negative values would wrap. make_grid ignores value_range unless normalize=True.
        generated_ecog = traj[-1, :1].view([-1, 4, 128, 128]).clip(0, 1)
        grid = make_grid(generated_ecog, padding=0, nrow=1)
        generated_ecog_display = ToPILImage()(grid.cpu())

        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (eeg_image_display, "Input EEG"),
            (ecog_image_display, "Target ECoG"),
            (generated_ecog_display, "Generated ECoG from EEG by SBCFM"),
        )
        for ax, (image, title) in zip(axs, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')
        plt.show()
        plt.close(fig)  # release the figure so samples do not accumulate in memory

    torch.save(model_SB.state_dict(), os.path.join(checkpoint_dir, "CFM_SB_epoch=%d.pth" % epoch))

wandb.finish()

In [None]:
# Generate an ECoG map from a held-out EEG map with the trained Schrödinger-bridge CFM model.
USE_TORCH_DIFFEQ = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_directory = "/beegfs/home/ruslan.kalimullin/MScThesis/topo-maps"

to_pil = ToPILImage()
index_to_display = min(100, len(test_dataset) - 1)  # stay in range for small test sets
data_item = test_dataset[index_to_display]
# Dataset items are already transformed tensors; use them directly instead of a
# tensor -> PIL (uint8) -> tensor round trip.
eeg_image_tensor = data_item["eeg_image"].unsqueeze(0).to(device)
ecog_image_tensor = data_item["ecog_image"].unsqueeze(0).to(device)

# Integrate the learned velocity field from t=0 (EEG) to t=1 (generated ECoG).
t_span = torch.linspace(0, 1, 2, device=device)
model_SB.eval()
with torch.no_grad():
    if USE_TORCH_DIFFEQ:
        traj = torchdiffeq.odeint(
            lambda t, x: model_SB.forward(t, x),
            eeg_image_tensor,
            t_span,
            atol=1e-4,
            rtol=1e-4,
            method="dopri5",
        )
    else:
        # Build the torchdyn solver for model_SB explicitly, rather than relying on
        # whichever NeuralODE the global `node` currently wraps.
        sb_node = NeuralODE(model_SB, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4).to(device)
        traj = sb_node.trajectory(eeg_image_tensor, t_span=t_span)

eeg_image_display = to_pil(eeg_image_tensor.squeeze(0).cpu())
ecog_image_display = to_pil(ecog_image_tensor.squeeze(0).cpu())

# Data comes from ToTensor, so clip to [0, 1] before ToPILImage (which scales by 255).
generated_ecog = traj[-1, :1].view([-1, 4, 128, 128]).clip(0, 1)
grid = make_grid(generated_ecog, padding=0, nrow=1)
generated_ecog_display = to_pil(grid.cpu())

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for ax, image, title in zip(
    axs,
    (eeg_image_display, ecog_image_display, generated_ecog_display),
    ("Input EEG", "Target ECoG", "Generated ECoG from EEG"),
):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
plt.show()