In [None]:
# Colab only: mount Google Drive so the project folder and dataset on Drive are reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')


In [None]:
# Work from the project folder on Drive so the local `utils` / `model` modules and the
# relative dataset paths below resolve from here. A single pure-Python chdir is enough;
# the extra `%cd` magic only repeated the same change.
import os

PROJECT_DIR = '/content/drive/My Drive/MindEye'
os.chdir(PROJECT_DIR)
os.getcwd()

In [None]:
!pip install open_clip_torch dalle2-pytorch wandb
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import numpy as np
import pandas as pd
import csv
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from utils import load_image, save_image, encode_img, decode_img, to_PIL
from model import Clipper, BrainNetwork, DiffusionPriorNetwork, BrainDiffusionPrior, OpenClipper
import clip


In [None]:
import wandb

# Authenticate without embedding a secret in the notebook: wandb.login() uses the
# WANDB_API_KEY environment variable or ~/.netrc when present, and prompts otherwise.
# NOTE: an API key was previously hard-coded in this cell; it is exposed in the notebook
# history and should be revoked/rotated.
wandb.login()


In [None]:
# Dataset layout. Each '{}' is filled with the subject number, e.g. .format(1) -> 'subj01'.
dataset_path = '../2023-Machine-Learning-Dataset/'
subject_path = dataset_path + 'subj0{}/'

training_path = subject_path + 'training_split/'
training_fmri_path = training_path + 'training_fmri/'
training_images_path = training_path + 'training_images/'

testing_path = subject_path + 'test_split/test_fmri/'
image_infos_path = dataset_path + 'image_infos/subj0{}_infos_train.csv'

In [None]:
class MyDataset(Dataset):
  """Pairs row ``idx`` of an fMRI array with the ``idx``-th image file of a folder.

  Image files are ordered by natural filename order (``2.png`` before ``10.png``,
  ``train-0002_...`` before ``train-0010_...``) so that the pairing is deterministic.
  NOTE(review): assumes the folder holds exactly one image per fMRI row and that this
  filename order matches the fMRI row order — confirm against the dataset layout.
  """

  def __init__(self, fmri_data, images_folder, transform=None):
    """
    Args:
      fmri_data: indexable sequence of per-sample fMRI vectors (e.g. a 2-D numpy array).
      images_folder: directory containing the image files.
      transform: optional callable applied to each image returned by ``load_image``.
    """
    self.fmri_data = fmri_data
    self.images_folder = images_folder
    # os.listdir returns entries in arbitrary order; sort so image idx lines up with fMRI row idx.
    filenames = sorted(os.listdir(images_folder), key=self._natural_key)
    self.image_paths = [os.path.join(images_folder, filename) for filename in filenames]
    self.transform = transform

  @staticmethod
  def _natural_key(filename):
    """Sort key that compares digit runs numerically and other text lexically."""
    import re
    # re.split with a capturing group alternates text (even index) / digit run (odd index),
    # so keys always compare str with str and int with int.
    return [int(part) if i % 2 else part for i, part in enumerate(re.split(r'(\d+)', filename))]

  def __len__(self):
    return len(self.fmri_data)

  def __getitem__(self, idx):
    """Return ``(fmri, image)`` for sample ``idx``; ``image`` comes from ``load_image``."""
    fmri = self.fmri_data[idx]
    image_path = self.image_paths[idx]
    image = load_image(image_path)

    if self.transform:
      image = self.transform(image)

    return fmri, image

In [None]:
def mixco_nce(preds, targs, temp=0.1, perm=None, betas=None, select=None, distributed=False,
              accelerator=None, local_rank=None, bidirectional=True):
    """InfoNCE loss between two batches of embeddings, with optional MixCo soft targets.

    Args:
        preds: (N, D) predicted embeddings (the training loop passes L2-normalised
            voxel projections).
        targs: (N, D) target embeddings (L2-normalised CLIP image embeddings in the loop).
        temp: softmax temperature applied to the similarity matrix.
        perm, betas, select: outputs of ``mixco``. When all three are given, row i's
            target distribution puts ``betas[i]`` on target i and ``1 - betas[i]`` on
            target ``perm[i]``; otherwise row i's hard target is target i.
        distributed, accelerator, local_rank: accepted for signature compatibility; unused.
        bidirectional: also score the transposed (target -> prediction) matrix and
            average both directions.

    Returns:
        Scalar loss tensor.
    """
    brain_clip = (preds @ targs.T) / temp

    if perm is not None and betas is not None and select is not None:
        rows = torch.arange(preds.shape[0], device=preds.device)
        probs = torch.diag(betas)
        # Accumulate instead of assigning: when perm[i] == i the sample was mixed with
        # itself and both weights belong on the diagonal (summing to 1). Plain assignment
        # overwrote betas[i], leaving e.g. an all-zero target row for unselected samples
        # (betas[i] == 1). The (i, perm[i]) pairs are unique, so += is well defined.
        probs[rows, perm.to(preds.device)] += 1 - betas

        loss = -(brain_clip.log_softmax(-1) * probs).sum(-1).mean()
        if bidirectional:
            loss2 = -(brain_clip.T.log_softmax(-1) * probs.T).sum(-1).mean()
            loss = (loss + loss2) / 2
        return loss

    labels = torch.arange(brain_clip.shape[0], device=brain_clip.device)
    loss = F.cross_entropy(brain_clip, labels)
    if bidirectional:
        loss2 = F.cross_entropy(brain_clip.T, labels)
        loss = (loss + loss2) / 2
    return loss
def mixco(voxels, beta=0.15, s_thresh=0.5):
    """MixCo augmentation: blend a random subset of samples with randomly chosen partners.

    Each sample i is paired with sample ``perm[i]``. Samples picked by ``select`` are
    replaced by ``betas[i] * x_i + (1 - betas[i]) * x_perm[i]`` with ``betas ~ Beta(beta, beta)``;
    the others keep their data and get ``betas[i] = 1``.

    Args:
        voxels: batch tensor whose first dimension is the batch; modified IN PLACE.
        beta: concentration of the symmetric Beta distribution for mixing weights.
        s_thresh: probability that a sample is selected for mixing.

    Returns:
        (voxels, perm, betas, select) — the same (mutated) ``voxels`` tensor, the CPU
        permutation, the per-sample weights and the boolean selection mask.
    """
    batch = voxels.shape[0]
    # Random draws happen in a fixed order: permutation, mixing weights, selection mask.
    perm = torch.randperm(batch)
    partners = voxels[perm].to(voxels.device, dtype=voxels.dtype)
    betas = torch.distributions.Beta(beta, beta).sample([batch]).to(voxels.device, dtype=voxels.dtype)
    select = (torch.rand(batch) <= s_thresh).to(voxels.device)

    # Per-sample weights reshaped to broadcast over every non-batch dimension.
    weight = betas[select].reshape(-1, *([1] * (voxels.dim() - 1)))
    voxels[select] = voxels[select] * weight + partners[select] * (1 - weight)
    betas.masked_fill_(~select, 1)
    return voxels, perm, betas, select

In [None]:
transform = transforms.Resize([512, 512])

# Load subject 1's training fMRI: left and right hemispheres concatenated along axis 1,
# then paired with the subject's training images.
subj01_fmri_dir = training_fmri_path.format(1)
lh = np.load(subj01_fmri_dir + 'lh_training_fmri.npy')
rh = np.load(subj01_fmri_dir + 'rh_training_fmri.npy')
lrh = np.concatenate([lh, rh], axis=1)

my_dataset = MyDataset(lrh, training_images_path.format(1), transform=transform)

In [None]:
# pip install open_clip_torch

import open_clip
from PIL import Image
from torchvision import transforms

# model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# open_clip wrapper from model.py; its embed_image output is the CLIP target in training.
clip_extractor = OpenClipper(device=device).to(device)

In [None]:
# Model hyperparameters for voxel2clip / the diffusion prior, plus training-history buffers.
# (num_epochs lives in the training-configuration cell.)
clip_size = 768
out_dim = clip_size  # 257*clip_size
depth = 6
dim_head = 64
timesteps = 100
heads = clip_size // dim_head  # attention heads of width dim_head spanning clip_size
epoch = 0  # first epoch to run; the checkpoint-loading cell may overwrite it when resuming
losses = []  # per-step training losses
val_losses = []  # per-step validation losses
lrs = []  # learning rate recorded at each training step

In [None]:
# Optimisation / data-loading configuration.
batch_size = 16
num_epochs = 120
num_train = 5000  # nominal training-set size; the real size is set by the train/val split
lr_scheduler_type = 'cycle'  # schedule name; the scheduler object itself is bound to `lr_scheduler` later
initial_lr = 1e-3  # lr passed to AdamW; OneCycleLR overrides it, starting at max_lr / div_factor
max_lr = 5e-4  # peak lr of the one-cycle schedule
random_seed = 42
train_size = 0.8
valid_size = 1 - train_size
# Used as the DataLoader worker count: one per visible GPU, at least 1.
num_workers = torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1
prior_mult = .03  # NOTE(review): not read by any visible cell

In [None]:
# voxel2clip: maps fMRI voxels into CLIP space; with use_projector it also returns a
# projected embedding (the second output used by the contrastive loss).
voxel2clip_kwargs = {'out_dim': out_dim, 'clip_size': clip_size, 'use_projector': True}
voxel2clip = BrainNetwork(**voxel2clip_kwargs).to(device)

In [None]:
# Transformer of the diffusion prior: 257 tokens of width out_dim, non-causal attention.
prior_network_kwargs = dict(
    dim=out_dim,
    depth=depth,
    dim_head=dim_head,
    heads=heads,
    causal=False,
    num_tokens=257,
    learned_query_mode="pos_emb",
)
prior_network = DiffusionPriorNetwork(**prior_network_kwargs).to(device)

# Diffusion prior wrapping the transformer and the voxel2clip encoder (no text conditioning).
diffusion_prior_kwargs = dict(
    net=prior_network,
    image_embed_dim=out_dim,
    condition_on_text_encodings=False,
    timesteps=timesteps,
    cond_drop_prob=0.2,
    image_embed_scale=None,
    voxel2clip=voxel2clip,
)
diffusion_prior = BrainDiffusionPrior(**diffusion_prior_kwargs).to(device)

In [None]:
# Start a wandb run and record the hyperparameters needed to reproduce it.
wandb.init(
    # set the wandb project where this run will be logged
    project="clipTrain",

    # track hyperparameters and run metadata
    config={
        "learning_rate": initial_lr,
        # OneCycleLR overrides the optimizer lr, so the peak lr actually used is max_lr.
        "max_lr": max_lr,
        "batch_size": batch_size,
        "architecture": "MLP",
        "dataset": "NSD",
        "epochs": num_epochs,
        "random_seed": random_seed,
        "train_size": train_size,
        "valid_size": valid_size
    }
)

In [None]:
# AdamW parameter groups: weight decay on all weights except biases and LayerNorm
# parameters, for both the prior network and the voxel2clip encoder.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']


def _decay_groups(module, weight_decay=1e-2):
    """Return [decayed group, non-decayed group] for `module`'s named parameters."""
    decayed, undecayed = [], []
    for name, param in module.named_parameters():
        target = undecayed if any(key in name for key in no_decay) else decayed
        target.append(param)
    return [
        {'params': decayed, 'weight_decay': weight_decay},
        {'params': undecayed, 'weight_decay': 0.0},
    ]


opt_grouped_parameters = _decay_groups(diffusion_prior.net) + _decay_groups(diffusion_prior.voxel2clip)

In [None]:
import math

# AdamW over the parameter groups; its lr is then driven by a one-cycle schedule that
# the training loop steps once per training batch.
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=initial_lr)

# OneCycleLR raises once it is stepped past total_steps, so size it from the real dataset
# instead of the nominal `num_train`, and do not divide by `num_workers` (that is the
# DataLoader worker count here, not a number of training processes).
# random_split(fractions) gives the training subset floor(n * train_size) samples plus at
# most one leftover sample, and the DataLoader keeps the final partial batch.
max_train_samples = math.floor(len(my_dataset) * train_size) + 1
steps_per_epoch = math.ceil(max_train_samples / batch_size)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr,
                          total_steps=num_epochs * steps_per_epoch,
                          final_div_factor=1000,
                          last_epoch=-1, pct_start=2/num_epochs)

In [None]:
# train-val split
generator = torch.Generator().manual_seed(random_seed)
trainset, validset = random_split(my_dataset, [train_size, valid_size], generator=generator)

In [None]:
# build dataloader
train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_dataloader = DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Load model: optionally resume from a checkpoint written by the training loop,
# e.g. resume_checkpoint = "./ModelsClip/31". Leave as None to train from scratch.
resume_checkpoint = None

if resume_checkpoint is not None:
    checkpoint = torch.load(resume_checkpoint, map_location=device)
    # Checkpoints are saved after an epoch's training pass, so continue with the next epoch.
    epoch = checkpoint['epoch'] + 1
    loss = checkpoint['loss']
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    diffusion_prior.load_state_dict(checkpoint['model_state_dict'])
    # Older checkpoints have no scheduler state; the schedule then restarts.
    if 'scheduler_state_dict' in checkpoint:
        lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    del checkpoint

In [None]:
# Epoch-level progress bar starting at `epoch` (0 unless set by the checkpoint-loading cell).
remaining_epochs = range(epoch, num_epochs)
progress_bar = tqdm(remaining_epochs, ncols=150)

In [None]:
# Sample image for debugging; only referenced by the commented-out inspection cell below.
sample_image_path = f"{training_images_path.format(1)}/0.png"
testImg = Image.open(sample_image_path)

In [None]:
# Debug scratch (disabled): runs a single training batch through voxel2clip and the CLIP
# extractor to print voxel / image / embedding shapes, then stops.
# for train_i, data in enumerate(train_dataloader):
#   voxels, images = data
#   optimizer.zero_grad()
#   print(voxels.shape,images.shape)
#   voxels = voxels.to(device).float()
#   images = images.to(device).float()
#   clip_voxels, clip_voxels_proj = diffusion_prior.voxel2clip(voxels)
#   print(clip_voxels.shape,clip_voxels_proj.shape)
#   clip_voxels.view(len(voxels),-1,768)
#   print(clip_extractor.embed_image(transforms.functional.to_pil_image(images[0].squeeze(0))).shape)
#   break

In [None]:
# Debug scratch (disabled): displays `testImg` after CLIP preprocessing.
# NOTE(review): `clip_extractor_cpu` is not defined in any visible cell; create it (or use
# `clip_extractor`) before re-enabling this.
# from matplotlib import pyplot as plt
# plt.imshow(np.transpose(clip_extractor_cpu.preprocess(testImg).numpy(),(1,2,0)))
# # plt.imshow(np.transpose(testImg.squeeze(0),(1,2,0)))
# plt.show()

In [None]:
# Train voxel2clip with the MixCo contrastive loss against CLIP image embeddings.
# Each epoch: one pass over the training loader with per-step wandb logging, a checkpoint
# at ./ModelsClip/<epoch>, then an un-augmented validation pass and an epoch-level log.
nce_temp = .006  # softmax temperature of the contrastive loss
checkpoint_dir = './ModelsClip'
os.makedirs(checkpoint_dir, exist_ok=True)


def embed_images(image_batch):
    """Return CLIP embeddings of every image in `image_batch`, concatenated along dim 0.

    Train and validation share this helper so both are scored with identical preprocessing.
    The targets are not optimized, so no autograd graph is built for them.
    NOTE(review): images are passed as tensors (as the training step did); confirm
    `OpenClipper.embed_image` expects tensors rather than PIL images.
    """
    with torch.no_grad():
        return torch.cat([clip_extractor.embed_image(image.to(device)).to(device) for image in image_batch])


def normalized_pair(voxel_batch, image_batch):
    """Return (voxel projections, CLIP targets), each flattened, L2-normalised and float32."""
    clip_target = embed_images(image_batch)
    _, clip_voxels_proj = diffusion_prior.voxel2clip(voxel_batch)
    clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1).float()
    clip_voxels_norm = nn.functional.normalize(clip_voxels_proj.flatten(1), dim=-1).float()
    return clip_voxels_norm, clip_target_norm


for epoch in progress_bar:
    diffusion_prior.train()

    loss_sum = 0
    val_loss_sum = 0

    for train_i, (voxels, images) in enumerate(train_dataloader):
        optimizer.zero_grad()
        voxels = voxels.to(device).float()
        # Mix BEFORE the forward pass: mixco's soft targets only describe the predictions
        # if those were computed from the mixed voxels. mixco also edits `voxels` in place,
        # which after the forward pass would modify a tensor autograd may have saved.
        voxels, perm, betas, select = mixco(voxels)

        clip_voxels_norm, clip_target_norm = normalized_pair(voxels, images)
        loss = mixco_nce(clip_voxels_norm, clip_target_norm, temp=nce_temp,
                         perm=perm, betas=betas, select=select)

        loss_sum += loss.item()
        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # "train/loss_mse" keeps its existing dashboard key; the value is the running mean
        # of the contrastive loss over this epoch.
        logs = {
            "train/loss": np.mean(losses[-(train_i+1):]),
            "train/lr": lrs[-1],
            "train/num_steps": len(losses),
            "train/loss_mse": loss_sum / (train_i + 1)
        }
        wandb.log(logs)

        progress_bar.set_postfix(**logs)

    torch.save({
        'epoch': epoch,
        'model_state_dict': diffusion_prior.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
        'loss': loss.item(),  # plain float: no autograd graph in the checkpoint
    }, os.path.join(checkpoint_dir, str(epoch)))

    diffusion_prior.eval()
    # No gradients and no random MixCo at validation: score the un-augmented pairs.
    with torch.no_grad():
        for val_i, (voxels, images) in enumerate(val_dataloader):
            voxels = voxels.to(device).float()
            clip_voxels_norm, clip_target_norm = normalized_pair(voxels, images)
            loss = mixco_nce(clip_voxels_norm, clip_target_norm, temp=nce_temp)

            val_loss_sum += loss.item()
            val_losses.append(loss.item())

    logs = {
        "train/loss": np.mean(losses[-(train_i+1):]),
        "val/loss": np.mean(val_losses[-(val_i+1):]),
        "train/lr": lrs[-1],
        "train/num_steps": len(losses),
        "train/loss_mse": loss_sum / (train_i + 1),
        "val/loss_mse": val_loss_sum / (val_i + 1)
    }
    wandb.log(logs)
