## Imports

In [1]:
from io import BytesIO
import os
from contextlib import nullcontext
import glob

import fire
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torch import autocast
from torchvision import transforms
import requests
import pandas as pd

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config

In [2]:
import re
import pandas as pd
import matplotlib.pyplot as plt
import math
import numpy as np
from skimage import io
import torchvision.transforms as T
from typing import List, Optional, Union

from scipy.signal import savgol_filter
from six.moves import xrange
import umap
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.utils.data import Dataset

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid

from ipynb.fs.full.WaveSeparatorCode import *

Install `cvxpy` for the L1-norm minimization used by the `PeriodStrength` function (Ramanujan methods).


## Stable-Diffusion Functions

In [3]:
def load_model_from_config(config, ckpt, device, verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Either a loaded OmegaConf config exposing a ``model`` entry, or
            a path to a YAML file readable by ``OmegaConf.load``.
        ckpt: Path to a checkpoint that contains a ``"state_dict"`` entry.
        device: Device the finished model is moved to.
        verbose: When true, print the missing / unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # Several cells hand over the YAML path instead of a loaded config; a plain
    # string has no ``.model`` attribute, so load it here.
    if isinstance(config, (str, os.PathLike)):
        config = OmegaConf.load(config)

    print(f"Loading model from {ckpt}")

    # Deserialize on the CPU: mapping the checkpoint straight onto the GPU keeps
    # a second copy of every weight there until ``pl_sd`` is released, which is
    # costly on small cards. NOTE: torch.load unpickles the file, so only load
    # checkpoints from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; they are only reported when verbose.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.to(device)
    model.eval()
    return model

In [4]:
def load_im(im_path):
    """Load an image from a local path or an http(s) URL as a model input.

    Args:
        im_path: File path, or a URL starting with ``"http"``.

    Returns:
        The image as a tensor with a leading batch dimension of 1, rescaled
        from ``ToTensor``'s [0, 1] range to [-1, 1].

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    if im_path.startswith("http"):
        # A timeout keeps a stalled download from hanging the notebook forever.
        response = requests.get(im_path, timeout=30)
        response.raise_for_status()
        im = Image.open(BytesIO(response.content))
    else:
        im = Image.open(im_path)
    # Convert in both branches so downloaded RGBA / greyscale / palette images
    # yield the same 3-channel layout as local files.
    im = im.convert("RGB")
    tforms = transforms.Compose([
        # transforms.Resize(224),
        # transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
    ])
    inp = tforms(im).unsqueeze(0)
    return inp*2-1

def prep_im(img):
    """Turn an already-loaded image into a batched model input in [-1, 1].

    Args:
        img: A PIL image / ndarray (converted with ``ToTensor``), or a tensor
            that is already in ``ToTensor`` layout and [0, 1] range, such as
            the items produced by the dataset in this notebook.

    Returns:
        The image with a leading batch dimension of 1, rescaled with ``x*2-1``.
    """
    if isinstance(img, torch.Tensor):
        # ToTensor rejects tensors, so pass them through unchanged.
        inp = img
    else:
        tforms = transforms.Compose([
            # transforms.Resize(224),
            # transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
        ])
        inp = tforms(img)
    return inp.unsqueeze(0)*2-1

In [5]:
@torch.no_grad()
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, ddim_eta):
    """Draw ``n_samples`` images conditioned on ``input_im`` and decode them.

    The conditioning returned by ``model.get_learned_conditioning`` is tiled
    ``n_samples`` times along its first axis. When ``scale`` differs from 1.0
    an all-zero tensor of the same shape is supplied as the unconditional
    conditioning; otherwise none is passed. Sampling uses a latent shape of
    ``[4, h // 8, w // 8]``.

    Returns:
        The decoded samples mapped through ``(x + 1) / 2`` and clamped to [0, 1].
    """
    use_autocast = precision == "autocast"
    precision_scope = autocast if use_autocast else nullcontext

    with precision_scope("cuda"), model.ema_scope():
        cond = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
        uncond = None if scale == 1.0 else torch.zeros_like(cond)
        latent_shape = [4, h // 8, w // 8]

        latents, _ = sampler.sample(
            S=ddim_steps,
            conditioning=cond,
            batch_size=n_samples,
            shape=latent_shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uncond,
            eta=ddim_eta,
            x_T=None,
        )

        decoded = model.decode_first_stage(latents)
        rescaled = (decoded + 1.0) / 2.0
        return torch.clamp(rescaled, min=0.0, max=1.0)

In [6]:

def main(
    im_path="data/example_conditioning/superresolution/sample_0.jpg",
    ckpt="models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt",
    config="configs/stable-diffusion/sd-image-condition-finetune.yaml",
    outpath="im_variations",
    scale=3.0,
    h=512,
    w=512,
    n_samples=4,
    precision="fp32",
    plms=True,
    ddim_steps=50,
    ddim_eta=0.0,
    device_idx=0,
    save=True,
    eval=True,
    ):
    """Generate image variations for every input image and optionally score them.

    Args:
        im_path: A glob pattern, a single http(s) URL, or an iterable of
            paths / URLs accepted by ``load_im``.
        ckpt: Checkpoint path handed to ``load_model_from_config``.
        config: Path of the YAML model config.
        outpath: Output directory; samples go to ``<outpath>/samples``.
        scale, h, w, n_samples, precision, ddim_steps, ddim_eta: Forwarded to
            ``sample_model``.
        plms: Use ``PLMSSampler`` (forces ``ddim_eta`` to 0.0) instead of
            ``DDIMSampler``.
        device_idx: Index of the CUDA device to run on.
        save: Write each generated sample as a PNG.
        eval: Compute the mean cosine similarity between the conditioning of
            the input image and that of its generated samples, and write the
            per-image values to ``eval.csv``.
    """
    device = f"cuda:{device_idx}"
    config = OmegaConf.load(config)
    model = load_model_from_config(config, ckpt, device=device)

    if plms:
        sampler = PLMSSampler(model)
        ddim_eta = 0.0
    else:
        sampler = DDIMSampler(model)

    os.makedirs(outpath, exist_ok=True)

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    # Continue numbering after whatever is already in the samples folder.
    base_count = len(os.listdir(sample_path))

    if isinstance(im_path, str):
        if im_path.startswith("http"):
            # glob cannot expand a URL and would silently return nothing.
            im_paths = [im_path]
        else:
            im_paths = glob.glob(im_path)
    else:
        # Previously a non-string argument left ``im_paths`` undefined.
        im_paths = list(im_path)
    im_paths = sorted(im_paths)

    all_similarities = []

    for im in im_paths:
        input_im = load_im(im).to(device)

        x_samples_ddim = sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, ddim_eta)
        if save:
            for x_sample in x_samples_ddim:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                filename = os.path.join(sample_path, f"{base_count:05}.png")
                Image.fromarray(x_sample.astype(np.uint8)).save(filename)
                base_count += 1

        if eval:
            # Only the similarity value is needed, so skip autograd bookkeeping.
            with torch.no_grad():
                # NOTE(review): the samples are in [0, 1] while input_im is in
                # [-1, 1]; confirm which range the conditioner expects.
                generated_embed = model.get_learned_conditioning(x_samples_ddim).squeeze(1)
                prompt_embed = model.get_learned_conditioning(input_im).squeeze(1)

                generated_embed /= generated_embed.norm(dim=-1, keepdim=True)
                prompt_embed /= prompt_embed.norm(dim=-1, keepdim=True)
                similarity = prompt_embed @ generated_embed.T
                mean_sim = similarity.mean()
            all_similarities.append(mean_sim.unsqueeze(0))

    if not all_similarities:
        # eval=False or no matching inputs: nothing to report, and
        # torch.cat on an empty list would raise.
        return

    df = pd.DataFrame(zip(im_paths, [x.item() for x in all_similarities]), columns=["filename", "similarity"])
    df.to_csv(os.path.join(sample_path, "eval.csv"))
    print(torch.cat(all_similarities).mean())



In [7]:

# Script entry point: hands main() to python-fire.
# NOTE(review): inside a notebook kernel __name__ is also "__main__", so running
# this cell executes main() with its default relative paths; that looks like the
# source of the recorded FileNotFoundError. Confirm whether this cell should be
# kept in the notebook at all.
if __name__ == "__main__":
    fire.Fire(main)

FileNotFoundError: [Errno 2] No such file or directory: 'configs/stable-diffusion/sd-image-condition-finetune.yaml'

In [9]:
# Root of the local stable-diffusion checkout; checkpoint paths are joined onto it.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
sd_root = r"C:\Users\bruno\OneDrive\Desktop\BrainReader RESEARCH\Code\external_gits\StableDiffusion_img2img\stable-diffusion"

In [16]:
# Pick the compute device and report the visible GPUs.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Found {torch.cuda.device_count()} available GPU(s)")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    # Fall back to the CPU so `device` is defined on every machine; later cells
    # read it unconditionally.
    device = torch.device("cpu")
    print("No GPU available")

Found 1 available GPU(s)
GPU 0: NVIDIA GeForce GTX 1650


In [10]:
device_idx=0
device = f"cuda:{device_idx}"
ckpt="models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt"
ckpt = os.path.join(sd_root, ckpt)
# Resolve the YAML against the stable-diffusion checkout and load it: the model
# factory reads `config.model`, which a bare path string does not have, and the
# relative path only resolves when the notebook runs from inside the checkout.
config_path = os.path.join(sd_root, "configs/stable-diffusion/sd-image-condition-finetune.yaml")
config = OmegaConf.load(config_path)
model = load_model_from_config(config, ckpt, device=device)
# NOTE(review): `train_loader` is created in a later section of this notebook,
# so this cell only works after that section has been run.
image = next(iter(train_loader))[0]
# Inference only: skip autograd bookkeeping to keep GPU memory use down.
with torch.no_grad():
    c = model.get_learned_conditioning(image)

Loading model from C:\Users\bruno\OneDrive\Desktop\BrainReader RESEARCH\Code\external_gits\StableDiffusion_img2img\stable-diffusion\models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt


RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 3.39 GiB already allocated; 0 bytes free; 3.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Loading the DataSet

In [7]:
# Folder holding the images and the Train/Val/Test list CSVs read by ImportImagenet.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
root_dir = (r"C:\Users\bruno\OneDrive\Desktop\Psychology\For Thesis Project\Datasets\My_ImageNet")

In [16]:
class ImportImagenet(Dataset):
    """Dataset over the images listed in ``<mode>_list.csv`` under ``root_dir``.

    The first CSV column is read as the image path relative to ``root_dir`` and
    the ``Target`` column as the integer label. Each image is converted with
    ``ToTensor``, centre-cropped to 200x200 and resized to 192.

    Args:
        root_dir: Folder containing the CSV lists and the images.
        mode: ``'Train'``, ``'Val'`` or ``'Test'``; selects the CSV file.
        model: Object providing ``get_learned_conditioning``; only used when
            ``clip`` is true.
        clip: When true, ``__getitem__`` returns ``(image, image_token)``
            instead of the image alone.
        device: Device the returned tensors live on. Defaults to CUDA when it
            is available, otherwise the CPU.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    _CSV_BY_MODE = {
        'Test': 'Test_list.csv',
        'Val': 'Val_list.csv',
        'Train': 'Train_list.csv',
    }

    def __init__ (self, root_dir, mode, model, clip=False, device=None):
        if mode not in self._CSV_BY_MODE:
            # Previously an unknown mode failed later with an unbound local.
            raise ValueError(
                f"mode must be one of {sorted(self._CSV_BY_MODE)}, got {mode!r}"
            )
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.root_dir = root_dir
        self.model = model
        self.clip = clip
        # Kept on the instance so the class no longer depends on a notebook
        # global named `device` being defined before it is used.
        self.device = torch.device(device)
        self.csv = pd.read_csv(os.path.join(self.root_dir, self._CSV_BY_MODE[mode]), delimiter= ',')
        # Device-agnostic replacement for the CUDA-only legacy constructor.
        self.targets = torch.as_tensor(
            self.csv["Target"].to_numpy(), dtype=torch.long, device=self.device
        )

    def __len__(self):
        return len(self.csv)

    def __getitem__(self,index):
        img_path = os.path.join(self.root_dir, self.csv.iloc[index,0])

        # Load and preprocess in every mode; the CLIP branch needs the image too.
        image = io.imread(img_path)
        image = T.ToTensor()(image)
        image = T.CenterCrop((200, 200))(image)
        image = T.Resize(192)(image)
        image = image.to(self.device)
        # y_label = self.targets[index]

        if not self.clip:
            return image

        with torch.no_grad():
            # Assumes the conditioner takes a batch and returns one row per
            # item, hence the added / removed leading axis (TODO confirm).
            # NOTE(review): load_im rescales to [-1, 1] before conditioning;
            # confirm whether the same scaling is needed here.
            image_token = self.model.get_learned_conditioning(image.unsqueeze(0)).squeeze(0)

        return image, image_token.to(self.device)

        

         
       

#### Splitting the DataSet

In [9]:
batch = 5
subset = True
subset_cut = 10
shuffle = True
split_seed = 0  # fixes the train/val/test partition across kernel restarts

device_idx=0
device = f"cuda:{device_idx}"
ckpt="models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt"
ckpt = os.path.join(sd_root, ckpt)
# Resolve the YAML against the stable-diffusion checkout and load it: the model
# factory reads `config.model`, which a bare path string does not have.
config_path = os.path.join(sd_root, "configs/stable-diffusion/sd-image-condition-finetune.yaml")
config = OmegaConf.load(config_path)
model = load_model_from_config(config, ckpt, device=device)

data_set = ImportImagenet(root_dir, 'Train', model)

if subset:
    # Keep every `subset_cut`-th item to shrink the dataset.
    data_set = torch.utils.data.Subset(data_set, np.arange(0,len(data_set),subset_cut))

# 80 / 10 / 10 split; the train share absorbs the rounding remainder.
dataset_len = len(data_set)
val_split = int(0.1 * dataset_len)
test_split = int(0.1 * dataset_len)
train_split = dataset_len - val_split - test_split

train_set, val_set, test_set = torch.utils.data.random_split(
    data_set,
    [train_split, val_split, test_split],
    generator=torch.Generator().manual_seed(split_seed),
)

# Each loader reads its own partition. Previously all three wrapped `train_set`,
# so validation and test losses were computed on training data.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch , shuffle = shuffle)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch , shuffle = shuffle)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch , shuffle = shuffle)

NameError: name 'device' is not defined

In [None]:
#Visualize images (they appear distorted when processed for clip)

first_batch = next(iter(train_loader))
# The loader yields a bare image batch, or an (images, tokens) pair when the
# dataset also returns CLIP tokens; pick the image batch in either case.
images = first_batch[0] if isinstance(first_batch, (list, tuple)) else first_batch

for image in images:
    # Open the figure before drawing so each image gets its own figure and no
    # empty figure is left over after the last one.
    plt.figure()
    plt.imshow(T.ToPILImage()(image.squeeze()))

## Simulating the electrodes

In [None]:
# Define the model for a single electrode (linear regression)

class LinearRegression(nn.Module):
    """A single ``nn.Linear`` layer exposed as a module.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
    """

    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        prediction = self.linear(x)
        return prediction

In [None]:
# Creating 124 electrodes

num_electrodes = 124
# Flattened image size. The dataset in this notebook resizes every image to
# 192x192, so the previous 150528 (= 3 * 224 * 224) did not match what the
# electrodes actually receive. Assumes 3-channel RGB images (TODO confirm).
image_channels = 3
image_side = 192
input_size = image_channels * image_side * image_side
output_size = 1
electrodes = nn.ModuleList([LinearRegression(input_size, output_size).to(device) for _ in range(num_electrodes)])

# Pass the image through each electrode and collect the outputs

def readBrain(images, electrodes = electrodes, noise_var = 0.1):
    """Simulate noisy electrode readings for a batch of images.

    Every image is flattened and passed through every electrode; the outputs
    are concatenated per image and zero-mean Gaussian noise is added.

    Args:
        images: Batch tensor whose first axis indexes images, or an iterable
            of equally shaped image tensors.
        electrodes: Iterable of modules applied to the flattened images.
        noise_var: Variance of the additive Gaussian noise.

    Returns:
        Tensor with one row per image holding the concatenated electrode
        outputs plus noise. Gradients are not tracked.
    """
    with torch.no_grad():
        if not isinstance(images, torch.Tensor):
            images = torch.stack(list(images), dim=0)

        # One batched forward pass per electrode instead of one per image.
        flat_images = images.flatten(start_dim=1)
        brain_data = torch.cat([electrode(flat_images) for electrode in electrodes], dim=1)

        # randn_like draws the noise directly on brain_data's device and dtype,
        # so this no longer depends on a notebook-level `device` variable.
        noise = (noise_var ** 0.5) * torch.randn_like(brain_data)

        return brain_data + noise

## Reconstructing Latents from simulated electrodes

In [None]:
# Linear map from the 124 simulated electrode readings to a 512-dimensional output.
# NOTE(review): 512 is presumably the size of the target image tokens; confirm
# against the conditioner's output size.
reconstructor = LinearRegression(124, 512).to(device)

In [None]:
class RunManager():
    """Drives training, validation with early stopping, and testing of a network.

    The manager accumulates per-epoch losses, keeps the history of average
    training / validation losses, and raises ``val_stop`` / ``epoch_stop`` once
    the patience or the epoch budget is exhausted.

    Args:
        max_epoch: Maximum number of epochs, or ``None`` for no limit.
        learning_rate: Learning rate of the Adam optimizer.
        network: Network whose parameters are optimized. When omitted, the
            notebook-level ``reconstructor`` is used.

    Note:
        The optimizer is bound to ``network`` at construction time, so the
        network passed to ``train_val_run`` should be the same object.
    """

    def __init__(self, max_epoch, learning_rate = 0.01, network = None):

        if network is None:
            # Resolved at call time so the manager picks up the reconstructor
            # that exists when it is created.
            network = reconstructor

        self.epoch_count = 0
        self.best_vloss = None
        self.bad_validation_counter = 0
        self.val_stop = False
        self.saved_parameters1 = None   # snapshot of the best-validation weights
        self.max_epoch = max_epoch
        self.epoch_stop = False
        self.train_losses = []
        self.val_losses = []
        self.train_epoch_loss = 0
        self.val_epoch_loss = 0
        self.test_epoch_loss = 0
        self.loss_function = nn.MSELoss(reduction='mean')
        self.network = network
        # Optimize the managed network (previously this always used the global
        # `reconstructor`, even when another network was passed in).
        self.optimizer = optim.Adam(network.parameters(), lr=learning_rate)

    @staticmethod
    def _snapshot(network):
        """Return a detached copy of ``network``'s state dict."""
        # state_dict() hands out references to the live tensors; without the
        # clone the "saved" weights would keep changing as training continues.
        return {name: tensor.detach().clone() for name, tensor in network.state_dict().items()}

    @staticmethod
    def _mean_loss(total, n_batches):
        """Average an accumulated loss over the number of batches seen."""
        if n_batches == 0:
            raise ValueError("loader produced no batches; cannot average the loss")
        return total / n_batches

    def weights_init(self):
        """Re-initialise every submodule that exposes ``reset_parameters``."""
        for layer in self.network.modules():
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def val_early_stopping(self,  vloss, network, patience):
        """Update the early-stopping state with this epoch's validation loss.

        A strictly lower loss resets the bad-epoch counter and snapshots the
        weights into ``saved_parameters1``; anything else counts as a bad
        epoch. ``val_stop`` is set once ``patience`` bad epochs accumulate.
        """
        if self.best_vloss is None or vloss < self.best_vloss:
            self.best_vloss = vloss
            self.bad_validation_counter = 0
            self.saved_parameters1 = self._snapshot(network)
        else:
            self.bad_validation_counter += 1

        # Stop if validation performance does not improve for patience number of epochs
        if self.bad_validation_counter >= patience:
            self.val_stop = True
            print("Val stop")

    def train_track_loss(self, loss, batch):
        """Add this batch's loss to the running training total."""
        self.train_epoch_loss += loss.item()

    def val_track_loss(self, loss, batch):
        """Add this batch's loss to the running validation total."""
        self.val_epoch_loss += loss.item()

    def test_track_loss(self, loss, batch):
        """Add this batch's loss to the running test total."""
        self.test_epoch_loss += loss.item()

    def begin_epoch(self):
        """Advance the epoch counter and reset the per-epoch loss totals."""
        self.epoch_count += 1
        self.train_epoch_loss = 0
        self.test_epoch_loss = 0
        self.val_epoch_loss = 0

        # Stop if max_epoch is reached
        if self.max_epoch is not None and self.epoch_count >= self.max_epoch:
            self.epoch_stop = True

    def networkStep(self, network, loader, mode):
        """Run ``network`` over ``loader`` once and return the number of batches.

        ``loader`` must yield ``(images, tokens)`` pairs. Images are turned
        into simulated electrode data with ``readBrain``. In ``"Train"`` mode
        the optimizer is stepped; ``"Val"`` adds to the validation total and
        any other mode adds to the test total.
        """
        training = mode == "Train"
        n_batches = 0

        for batch in loader:
            n_batches += 1

            images, tokens = batch
            tokens = tokens.squeeze(1)
            brain_data = readBrain(images)

            if training:
                self.optimizer.zero_grad()
            output = network(brain_data)
            loss = self.loss_function(output, tokens)

            if training:
                loss.backward()
                self.optimizer.step()
                self.train_track_loss(loss, batch)
            elif mode == "Val":
                self.val_track_loss(loss, batch)
            else:
                self.test_track_loss(loss, batch)

        return n_batches

    def train_one_epoch(self, network, train_loader):
        """Train for one epoch and return the average loss per batch."""
        network.train()
        n_batches = self.networkStep(network, train_loader, "Train")
        return self._mean_loss(self.train_epoch_loss, n_batches)

    def val_one_epoch(self, network, val_loader):
        """Evaluate on the validation loader; return the average loss per batch."""
        with torch.no_grad():
            network.eval()
            n_batches = self.networkStep(network, val_loader, "Val")
        return self._mean_loss(self.val_epoch_loss, n_batches)

    def test_run(self, network, test_loader):
        """Evaluate on the test loader; return the average loss per batch."""
        # Reset only the test total: a test pass is not a training epoch, so it
        # must not advance `epoch_count` or trip `epoch_stop`.
        self.test_epoch_loss = 0
        with torch.no_grad():
            network.eval()
            n_batches = self.networkStep(network, test_loader, "Test")
        return self._mean_loss(self.test_epoch_loss, n_batches)

    def train_val_run(self, network, train_loader, val_loader, val_patience = 3):
        """Alternate training and validation epochs until a stop flag is set."""
        while not self.val_stop and not self.epoch_stop:

            print("Epoch: ", self.epoch_count+1)

            self.begin_epoch()

            last_loss = self.train_one_epoch(network, train_loader)
            self.train_losses.append(last_loss)
            last_vloss = self.val_one_epoch(network, val_loader)
            self.val_losses.append(last_vloss)
            self.val_early_stopping(last_vloss, network, val_patience)

#### Training

In [None]:
# Upper bound on the number of epochs; early stopping on the validation loss may
# end the run sooner.
epochs = 1000
run_manager = RunManager(epochs)
# Re-initialise the layers that expose reset_parameters() before training.
run_manager.weights_init()

run_manager.train_val_run(reconstructor, train_loader, val_loader)

# Average per-batch loss of the trained reconstructor on the test loader.
t_loss = run_manager.test_run(reconstructor, test_loader)



In [None]:
# Per-epoch average losses recorded by the run manager during train_val_run.
train_loss = run_manager.train_losses
val_loss = run_manager.val_losses

In [None]:
# torch.save(reconstructor.state_dict(), r"C:\Users\bruno\OneDrive\Desktop\BrainReader RESEARCH\Code\Project_git\Parameters_training_1")

#### Plotting Performance

In [None]:
# Training and validation loss curves on a single, explicitly created axes.
fig, ax = plt.subplots()
for curve, curve_label in ((train_loss, "Training Loss"), (val_loss, "Validation Loss")):
    ax.plot(curve, label=curve_label)

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Baseline: a freshly initialised reconstructor of the same shape, evaluated on
# the same test loader for comparison with the trained one.
untrained_reconstructor = LinearRegression(124, 512).to(device)
unt_loss = run_manager.test_run(untrained_reconstructor, test_loader)

In [None]:
# Compare the test loss of the untrained baseline with that of the trained network.
print("Loss before training: ", unt_loss, "Loss after training: ", t_loss)

## Decoding from latents

In [None]:
# Cloud GPU needed

# Sampling settings for the decoding cell; they mirror the defaults of main().
outpath="im_variations"
scale=3.0          # guidance scale; values other than 1.0 enable zeroed unconditional conditioning
h=512              # sample_model samples latents at h // 8 by w // 8
w=512
n_samples=4        # samples generated per input image
precision="fp32"   # "autocast" enables torch autocast; anything else runs without it
plms=True          # choose PLMSSampler over DDIMSampler
ddim_steps=50
ddim_eta=0.0
device_idx=0
save=True
# NOTE(review): `eval` shadows the Python builtin for the rest of the notebook.
eval=True

In [None]:
outpath = r"C:\Users\bruno\OneDrive\Desktop\BrainReader RESEARCH\Outputs\Trial1"

if plms:
    sampler = PLMSSampler(model)
    ddim_eta = 0.0
else:
    sampler = DDIMSampler(model)

os.makedirs(outpath, exist_ok=True)

sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
# Continue numbering after whatever is already in the samples folder.
base_count = len(os.listdir(sample_path))


batch = next(iter(train_loader))
# The loader yields a bare image batch, or an (images, tokens) pair when the
# dataset also returns CLIP tokens; pick the image batch in either case.
batch_images = batch[0] if isinstance(batch, (list, tuple)) else batch


for img in batch_images:

    # Dataset items are already tensors, which ToTensor (used by prep_im)
    # rejects, so build the model input directly: add a batch axis, move it to
    # the model's device and rescale from [0, 1] to [-1, 1] as load_im does.
    input_im = img.unsqueeze(0).to(device) * 2 - 1

    x_samples_ddim = sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, ddim_eta)

    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        filename = os.path.join(sample_path, f"{base_count:05}.png")
        Image.fromarray(x_sample.astype(np.uint8)).save(filename)
        base_count += 1
