In [1]:
import numpy as np
import torch
from torch import nn
import random

from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

import os

from datetime import datetime

In [2]:
# Training split for subject 1. CustomDataset consumes the betas and the caption
# embeddings; the *_neg_path entry is not read by any cell below.
train_betas_path, train_caps_embed_path, train_caps_embed_neg_path = (
    '../data/processed_data/subj01/nsd_train_fmriavg_nsdgeneral_sub1.npy',
    '../data/caps_embeds/train_caps_embeds_sub1.npy',
    '../data/caps_embeds/train_caps_embeds_negative_sub1.npy',
)

In [3]:
# Held-out test split for subject 1. CustomDataset consumes the betas and the
# caption embeddings; the *_neg_path entry is not read by any cell below.
test_betas_path, test_caps_embed_path, test_caps_embed_neg_path = (
    '../data/caps_embeds/test_betas_split_sub1.npy',
    '../data/caps_embeds/test_caps_embeds_split_sub1.npy',
    '../data/caps_embeds/test_caps_embeds_negative_split_sub1.npy',
)

In [4]:
# Validation split for subject 1 (used for checkpoint selection). CustomDataset
# consumes the betas and the caption embeddings; the *_neg_path entry is not read
# by any cell below.
# NOTE(review): this file is spelled '..._split_negative_...' while the test split
# uses '..._negative_split_...' — confirm which name exists on disk before use.
val_betas_path, val_caps_embed_path, val_caps_embed_neg_path = (
    '../data/caps_embeds/val_betas_split_sub1.npy',
    '../data/caps_embeds/val_caps_embeds_split_sub1.npy',
    '../data/caps_embeds/val_caps_embeds_split_negative_sub1.npy',
)

In [5]:
class CustomDataset(Dataset):
    """Paired (fMRI betas, caption embedding) samples held entirely on one device.

    Both ``.npy`` files are loaded eagerly, cast to float32 and moved to
    ``device`` once, so ``__getitem__`` is just a tensor index.

    Args:
        betas_path: path to a ``.npy`` array whose first axis indexes samples.
        embeds_path: path to a ``.npy`` array of targets with the same number of
            samples on its first axis. Size-1 non-sample axes (e.g. ``(N, 1, D)``)
            are squeezed away.
        device: where the tensors live; defaults to ``'cuda'`` (previous behavior).

    Raises:
        ValueError: if the two arrays hold a different number of samples.
    """

    def __init__(self, betas_path, embeds_path, device='cuda'):
        betas = np.load(betas_path)
        embeds = np.load(embeds_path)
        # Squeeze only the non-sample axes: a bare np.squeeze would also drop the
        # sample axis when the file contains a single sample.
        singleton_axes = tuple(
            ax for ax in range(1, embeds.ndim) if embeds.shape[ax] == 1
        )
        embeds = np.squeeze(embeds, axis=singleton_axes)
        if len(betas) != len(embeds):
            raise ValueError(
                f'{betas_path} has {len(betas)} samples but '
                f'{embeds_path} has {len(embeds)}'
            )
        self.betas = torch.from_numpy(betas).float().to(device)
        self.embeds = torch.from_numpy(embeds).float().to(device)

    def __len__(self):
        return len(self.betas)

    def __getitem__(self, index):
        return self.betas[index], self.embeds[index]


class CustomDataLoader:
    """Thin iterable wrapper around ``torch.utils.data.DataLoader``.

    Args:
        dataset: a map-style dataset (here a ``CustomDataset``).
        batch_size: samples per batch; the final batch may be smaller.
        shuffle: reshuffle each epoch. Defaults to ``True`` (the previous,
            hard-coded behavior); evaluation loaders can pass ``False``.
    """

    def __init__(self, dataset, batch_size, shuffle=True):
        self.data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    def __iter__(self):
        return iter(self.data_loader)

    def __len__(self):
        return len(self.data_loader)

batch_size = 64

train_dataset = CustomDataset(train_betas_path, train_caps_embed_path)
train_data_loader = CustomDataLoader(train_dataset, batch_size)

# Evaluation metrics do not depend on sample order, so val/test are not shuffled;
# a fixed order also means each per-batch logged value covers the same samples
# every epoch.
test_dataset = CustomDataset(test_betas_path, test_caps_embed_path)
test_data_loader = CustomDataLoader(test_dataset, batch_size, shuffle=False)

val_dataset = CustomDataset(val_betas_path, val_caps_embed_path)
val_data_loader = CustomDataLoader(val_dataset, batch_size, shuffle=False)

# Sanity check: shapes of one training batch.
for beta, embed in train_data_loader:
    print(f"beta      {beta.shape}")
    print(f"embed     {embed.shape}")
    break

beta      torch.Size([64, 15724])
embed     torch.Size([64, 1280])


In [6]:
# Read the layer widths off the loaded training tensors instead of hard-coding
# them, so the model always matches the arrays on disk.
input_dim = train_dataset.betas.shape[1]
print('input_dim:', input_dim)
# `ouput_dim` (sic) is the name the model-construction cell uses; `output_dim`
# is a correctly spelled alias for new code.
ouput_dim = train_dataset.embeds.shape[1]
output_dim = ouput_dim
print('output_dim:', ouput_dim)

input_dim: 15724
output_dim: 1280


In [7]:
class RegressionModel(nn.Module):
    """MLP regressor: input -> layer1 -> layer2 -> output, LeakyReLU between layers.

    The output layer has no activation.

    Args:
        input_channels: width of each input vector.
        output_channels: width of each predicted vector.
        layer1: width of the first hidden layer (default 8192, as before).
        layer2: width of the second hidden layer (default 2048, as before).
        negative_slope: LeakyReLU slope for negative inputs (default 0.05).
    """

    def __init__(self, input_channels, output_channels,
                 layer1=2**10 * 8, layer2=2**10 * 2, negative_slope=0.05):
        super().__init__()
        # A standalone `linear_regression` layer used to be created here but was
        # never called in forward(); in this notebook (15724 -> 1280) it only
        # wasted ~20M parameters of memory and checkpoint space. Checkpoints saved
        # with it need `load_state_dict(..., strict=False)`.
        self.layer1 = layer1
        self.layer2 = layer2

        # Linear layers stay at Sequential indices 0/2/4, so `mlp.*` checkpoint
        # keys match the previous layout.
        self.mlp = nn.Sequential(
            nn.Linear(in_features=input_channels, out_features=self.layer1),
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.Linear(in_features=self.layer1, out_features=self.layer2),
            nn.LeakyReLU(negative_slope, inplace=True),
            nn.Linear(in_features=self.layer2, out_features=output_channels),
        )

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_channels)`` to ``(..., output_channels)``."""
        return self.mlp(x)

# Build the regressor for the betas -> caption-embedding mapping and place it on the GPU.
model_embed = RegressionModel(input_channels=input_dim, output_channels=ouput_dim).to('cuda')

In [8]:
import torch.nn.functional as F

class CosineSimilarityLoss(nn.Module):
    """Loss ``1 - mean(cosine_similarity(y_true, y_pred, dim=1))``.

    For 2-D ``(batch, dim)`` inputs, dim 1 is the feature axis. The value lies in
    [0, 2]: 0 when every pair points the same way, 2 when every pair is opposite.
    The loss is symmetric in its two arguments.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_true, y_pred):
        similarity_per_pair = F.cosine_similarity(y_true, y_pred, dim=1)
        return 1 - similarity_per_pair.mean()

In [9]:
# MSE is the training objective; the cosine loss is only computed for monitoring
# in the validation/test loops.
cos_loss = CosineSimilarityLoss()
mse_loss = nn.MSELoss(reduction='mean')
# NOTE(review): weight_decay=1 is an unusually strong L2 term for Adam (which
# couples decay with the adaptive gradient) — confirm intent; AdamW decouples it.
optim_embed = torch.optim.Adam(
    model_embed.parameters(),
    lr=1e-4,
    weight_decay=1,
)

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# A fixed run name means re-running logs into the same TensorBoard directory and
# reuses the same checkpoint path (best_model_path is derived from it). Append
# e.g. datetime.now().strftime(...) to the name for distinct runs.
run_name = 'mlp_one_model'
writer = SummaryWriter(log_dir=os.path.join('runs', run_name))

In [11]:
import torch
from tqdm import tqdm

num_epochs = 100
iter_index = 0        # global training-step counter (TensorBoard x-axis)
test_iter_index = 0   # global eval-batch counter (TensorBoard x-axis)
# Lowest mean validation MSE so far; +inf guarantees the first epoch is saved.
min_test_loss = float('inf')
best_model_path = f'../models/{run_name}.pth'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(best_model_path) or '.', exist_ok=True)

for epoch in tqdm(range(num_epochs)):
    # ---- train: one optimizer step per mini-batch on the MSE objective ----
    model_embed.train()
    for betas, embeds in train_data_loader:
        iter_index += 1
        pred_embeds = model_embed(betas)

        loss_embed = mse_loss(pred_embeds, embeds)
        optim_embed.zero_grad()
        loss_embed.backward()
        optim_embed.step()

        writer.add_scalars('train/losses', {
            'embed': loss_embed.item(),
        }, iter_index)

    # ---- validate: the per-sample mean MSE selects the checkpoint ----
    # `test_loss` keeps its old name for later cells, but is computed on val_data_loader.
    model_embed.eval()
    val_mse_sum = 0.0
    n_val_samples = 0
    with torch.inference_mode():
        for betas, embeds in val_data_loader:
            test_iter_index += 1
            pred_embeds = model_embed(betas)

            loss_embed = mse_loss(pred_embeds, embeds)
            # cos_loss returns 1 - mean cosine similarity (a loss, despite the tag name).
            cos_embed = cos_loss(pred_embeds, embeds)

            writer.add_scalars('val/losses', {
                'embed': loss_embed.item(),
            }, test_iter_index)

            writer.add_scalars('val/cos_sim', {
                'cos_embed': cos_embed.item(),
            }, test_iter_index)

            # Weight each batch mean by its size so a short final batch does not skew the epoch mean.
            batch_len = betas.shape[0]
            val_mse_sum += loss_embed.item() * batch_len
            n_val_samples += batch_len

    if n_val_samples == 0:
        raise ValueError('val_data_loader yielded no samples; cannot select a checkpoint')
    test_loss = val_mse_sum / n_val_samples
    writer.add_scalar('val/epoch_mse', test_loss, epoch)

    if test_loss < min_test_loss:
        min_test_loss = test_loss
        print(f'Model saved with loss: {test_loss}')
        torch.save({
            'model_embed_state_dict': model_embed.state_dict(),
            'epoch': epoch,
            'min_test_loss': min_test_loss,
        }, best_model_path)

  0%|          | 0/100 [00:00<?, ?it/s]

Model saved with loss: 2.7120538651943207


  1%|          | 1/100 [00:13<21:47, 13.20s/it]

Model saved with loss: 2.3417610228061676


  2%|▏         | 2/100 [00:26<21:23, 13.09s/it]

Model saved with loss: 2.032508000731468


  3%|▎         | 3/100 [00:39<21:06, 13.06s/it]

Model saved with loss: 1.812092900276184


  4%|▍         | 4/100 [00:52<20:51, 13.03s/it]

Model saved with loss: 1.6159464567899704


  5%|▌         | 5/100 [01:05<20:36, 13.02s/it]

Model saved with loss: 1.4572540074586868


  6%|▌         | 6/100 [01:18<20:22, 13.01s/it]

Model saved with loss: 1.338688462972641


  7%|▋         | 7/100 [01:31<20:12, 13.04s/it]

Model saved with loss: 1.2492763549089432


  8%|▊         | 8/100 [01:44<19:59, 13.03s/it]

Model saved with loss: 1.1621614694595337


  9%|▉         | 9/100 [01:57<19:44, 13.02s/it]

Model saved with loss: 1.109212189912796


 10%|█         | 10/100 [02:10<19:31, 13.02s/it]

Model saved with loss: 1.0597576647996902


 11%|█         | 11/100 [02:23<19:19, 13.03s/it]

Model saved with loss: 1.018299974501133


 12%|█▏        | 12/100 [02:36<19:09, 13.06s/it]

Model saved with loss: 0.9905326217412949


 13%|█▎        | 13/100 [02:49<18:56, 13.07s/it]

Model saved with loss: 0.9645259454846382


 14%|█▍        | 14/100 [03:02<18:42, 13.05s/it]

Model saved with loss: 0.9402236863970757


 15%|█▌        | 15/100 [03:15<18:29, 13.05s/it]

Model saved with loss: 0.9317614734172821


 16%|█▌        | 16/100 [03:28<18:15, 13.04s/it]

Model saved with loss: 0.90281593054533


 17%|█▋        | 17/100 [03:41<18:00, 13.02s/it]

Model saved with loss: 0.8877482712268829


 18%|█▊        | 18/100 [03:54<17:48, 13.03s/it]

Model saved with loss: 0.8816548511385918


 20%|██        | 20/100 [04:19<16:45, 12.57s/it]

Model saved with loss: 0.8652753531932831


 21%|██        | 21/100 [04:32<16:42, 12.69s/it]

Model saved with loss: 0.8562572300434113


 22%|██▏       | 22/100 [04:45<16:37, 12.78s/it]

Model saved with loss: 0.8551634326577187


 23%|██▎       | 23/100 [04:58<16:29, 12.86s/it]

Model saved with loss: 0.8513344973325729


 24%|██▍       | 24/100 [05:11<16:21, 12.91s/it]

Model saved with loss: 0.8423099964857101


 25%|██▌       | 25/100 [05:24<16:12, 12.96s/it]

Model saved with loss: 0.839881032705307


 26%|██▌       | 26/100 [05:37<16:01, 12.99s/it]

Model saved with loss: 0.835456408560276


 27%|██▋       | 27/100 [05:50<15:50, 13.01s/it]

Model saved with loss: 0.8287530690431595


 28%|██▊       | 28/100 [06:03<15:37, 13.02s/it]

Model saved with loss: 0.8286100998520851


 29%|██▉       | 29/100 [06:16<15:25, 13.03s/it]

Model saved with loss: 0.8251029998064041


 30%|███       | 30/100 [06:29<15:12, 13.04s/it]

Model saved with loss: 0.8189946115016937


 32%|███▏      | 32/100 [06:54<14:15, 12.58s/it]

Model saved with loss: 0.816073901951313


 33%|███▎      | 33/100 [07:07<14:12, 12.73s/it]

Model saved with loss: 0.812539592385292


 34%|███▍      | 34/100 [07:20<14:06, 12.83s/it]

Model saved with loss: 0.807425431907177


 35%|███▌      | 35/100 [07:33<13:57, 12.89s/it]

Model saved with loss: 0.8071509301662445


 36%|███▌      | 36/100 [07:46<13:48, 12.94s/it]

Model saved with loss: 0.7994620427489281


 40%|████      | 40/100 [08:33<12:00, 12.01s/it]

Model saved with loss: 0.7966928258538246


 41%|████      | 41/100 [08:47<12:07, 12.32s/it]

Model saved with loss: 0.7964910864830017


100%|██████████| 100/100 [20:07<00:00, 12.08s/it]


In [12]:
# Evaluate the held-out test split with the checkpoint that had the lowest
# validation MSE (saved to best_model_path by the training loop), not with
# whatever weights the final epoch left in model_embed.
model_device = next(model_embed.parameters()).device
checkpoint = torch.load(best_model_path, map_location=model_device, weights_only=True)
model_embed.load_state_dict(checkpoint['model_embed_state_dict'])
model_embed.eval()

# Fresh, sample-weighted accumulators. The old code added onto the last
# validation value of `test_loss` and never normalised the result.
test_mse_sum = 0.0
test_cos_loss_sum = 0.0
n_test_samples = 0
with torch.inference_mode():
    n_batches = 0
    for betas, embeds in test_data_loader:
        test_iter_index += 1
        pred_embeds = model_embed(betas)

        loss_embed = mse_loss(pred_embeds, embeds)

        # cos_loss returns 1 - mean cosine similarity (a loss, despite the tag name).
        cos_embed = cos_loss(pred_embeds, embeds)

        writer.add_scalars('test/losses', {
            'embed': loss_embed.item(),
        }, test_iter_index)

        writer.add_scalars('test/cos_sim', {
            'cos_embed': cos_embed.item(),
        }, test_iter_index)

        n_batches += 1
        batch_len = betas.shape[0]
        test_mse_sum += loss_embed.item() * batch_len
        test_cos_loss_sum += cos_embed.item() * batch_len
        n_test_samples += batch_len

if n_test_samples == 0:
    raise ValueError('test_data_loader yielded no samples')
test_loss = test_mse_sum / n_test_samples          # per-sample mean test MSE
test_cos_loss = test_cos_loss_sum / n_test_samples  # 1 - mean test cosine similarity
writer.flush()
print(f"Test MSE: {test_loss:.6f} | test 1 - cos: {test_cos_loss:.6f} "
      f"(checkpoint from epoch {checkpoint['epoch']})")