In [1]:
import numpy as np
import torch
from torch import nn
import random

from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

import os

from datetime import datetime

In [2]:
train_betas_path = '../data/processed_data/subj01/nsd_train_fmriavg_nsdgeneral_sub1.npy'
train_caps_embed_path = '../data/caps_embeds/train_caps_embeds_sub1.npy'
train_caps_embed_neg_path = '../data/caps_embeds/train_caps_embeds_negative_sub1.npy'

train_betas = np.load(train_betas_path)
# train_caps_embed = np.load(train_caps_embed_path)
# train_caps_embed_neg = np.load(train_caps_embed_neg_path)

In [3]:
train_betas.shape

(8859, 15724)

In [4]:
test_betas_path = '../data/processed_data/subj01/nsd_test_fmriavg_nsdgeneral_sub1.npy'
test_caps_embed_path = '../data/caps_embeds/test_caps_embeds_sub1.npy'
test_caps_embed_neg_path = '../data/caps_embeds/test_caps_embeds_negative_sub1.npy'

# test_betas = np.load(test_betas_path)
# test_caps_embed = np.load(test_caps_embed_path)
# test_caps_embed_neg = np.load(test_caps_embed_neg_path)

In [5]:
class CustomDataset(Dataset):
    def __init__(self, betas_path, embeds_path, neg_embeds_path):
        self.betas = torch.from_numpy(np.load(betas_path)).float().cuda()
        self.embeds = torch.from_numpy(
            np.squeeze(np.load(embeds_path))
            ).float().cuda()
        self.embeds_neg = torch.from_numpy(
            np.squeeze(np.load(neg_embeds_path))
            ).float().cuda()

    def __len__(self):
        return len(self.betas)

    def __getitem__(self, index):
        return self.betas[index], self.embeds[index], self.embeds_neg[index]


class CustomDataLoader:
    def __init__(self, dataset, batch_size):
        self.data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    def __iter__(self):
        return iter(self.data_loader)

    def __len__(self):
        return len(self.data_loader)

batch_size = 64

train_dataset = CustomDataset(train_betas_path,
                        train_caps_embed_path,
                        train_caps_embed_neg_path)
train_data_loader = CustomDataLoader(train_dataset, batch_size)

test_dataset = CustomDataset(test_betas_path,
                              test_caps_embed_path,
                              test_caps_embed_neg_path)
test_data_loader = CustomDataLoader(test_dataset, batch_size)

for beta, embed, neg_embed in train_data_loader:
    print(f"beta      {beta.shape}")
    print(f"embed     {embed.shape}")
    print(f"neg_embed {neg_embed.shape}")
    break

beta      torch.Size([64, 15724])
embed     torch.Size([64, 1280])
neg_embed torch.Size([64, 1280])


In [6]:
input_dim = 15724
print('input_dim:', input_dim)
ouput_dim = 1280
print('output_dim:', ouput_dim)

input_dim: 15724
output_dim: 1280


In [7]:
class RegressionModel(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.linear_regression = nn.Linear(in_features=input_channels,
                               out_features=output_channels)
        self.layer1 = 2**10 * 4
        self.layer2 = 2**10 * 5

        self.mlp = nn.Sequential(
            nn.Linear(in_features=input_channels, 
                      out_features=self.layer1),
            nn.LeakyReLU(0.05, inplace=True),    
            nn.Linear(in_features=self.layer1, 
                      out_features=output_channels),
        )


    def forward(self, x):
        return self.mlp(x)

model_embed = RegressionModel(input_dim, ouput_dim).cuda()
model_neg = RegressionModel(input_dim, ouput_dim).cuda()

In [8]:
next(model_embed.parameters()).dtype

torch.float32

In [9]:
import torch.nn.functional as F

class CosineSimilarityLoss(nn.Module):
    def __init__(self):
        super(CosineSimilarityLoss, self).__init__()

    def forward(self, y_true, y_pred):
        cosine_sim = F.cosine_similarity(y_true, y_pred, dim=1)
        loss = 1 - cosine_sim.mean()
        return loss

In [10]:
cos_loss = CosineSimilarityLoss()
mse_loss = nn.MSELoss(reduction='mean')
optim_embed = torch.optim.Adam(params=model_embed.parameters(), lr=0.0001, weight_decay=1)
optim_neg = torch.optim.Adam(params=model_neg.parameters(), lr=0.0001, weight_decay=1)

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
run_name = 'mlp_only_cos' #+ datetime.now().strftime("%Y-%m-%d_%H:%M")
writer = SummaryWriter(os.path.join('runs', run_name))

In [12]:
import torch
from tqdm import tqdm

num_epochs = 100
iter_index = 0
min_test_loss = 1e15
best_model_path = f'../models/{run_name}.pth'
test_iter_index = 0

for epoch in tqdm(range(num_epochs)):
    model_embed.train()
    model_neg.train()

    for betas, embeds, neg_embeds in train_data_loader:
        iter_index += 1
        pred_embeds = model_embed(betas)
        pred_neg_embeds = model_neg(betas)

        loss_embed = mse_loss(pred_embeds, embeds)
        optim_embed.zero_grad()
        loss_embed.backward()
        optim_embed.step()
        
        loss_neg = mse_loss(pred_neg_embeds, neg_embeds)
        optim_neg.zero_grad()
        loss_neg.backward()
        optim_neg.step()

        writer.add_scalars('train/losses', {
            'embed': loss_embed.item(),
            'negat': loss_neg.item(),
        }, iter_index)

    test_loss = 0
    
    model_embed.eval()
    model_neg.eval()

    # if epoch % 10 == 0:
    with torch.inference_mode():
        n_batches = 0
        for betas, embeds, neg_embeds in test_data_loader:
            test_iter_index += 1
            pred_embeds = model_embed(betas)
            pred_neg_embeds = model_neg(betas)

            # loss_embed = mse_loss(pred_embeds, embeds)
            # loss_neg = mse_loss(pred_neg_embeds, neg_embeds)

            cos_embed = cos_loss(pred_embeds, embeds)
            cos_neg = cos_loss(pred_neg_embeds, neg_embeds)

            writer.add_scalars('test/losses', {
                'embed': loss_embed.item(),
                'negat': loss_neg.item(),
            }, test_iter_index)

            writer.add_scalars('test/cos_sim', {
                'cos_embed': cos_embed.item(),
                'cos_neg': cos_neg.item(),
            }, test_iter_index)

            n_batches += 1
            test_loss += loss_embed.item() + loss_neg.item()
    test_loss /= n_batches
    if test_loss < min_test_loss:
        min_test_loss = test_loss
        print(f'Model saved with loss: {test_loss}')
        torch.save({
            'model_embed_state_dict': model_embed.state_dict(),
            'model_neg_state_dict': model_neg.state_dict(),
            'epoch': epoch,
            'min_test_loss': min_test_loss,
        }, best_model_path)


  0%|          | 0/100 [00:00<?, ?it/s]

Model saved with loss: 52.54768180847168


  1%|          | 1/100 [00:12<20:40, 12.53s/it]

Model saved with loss: 43.72822570800781


  2%|▏         | 2/100 [00:25<20:38, 12.64s/it]

Model saved with loss: 42.85312461853027


  3%|▎         | 3/100 [00:37<20:27, 12.65s/it]

Model saved with loss: 34.55750846862793


  4%|▍         | 4/100 [00:50<20:13, 12.65s/it]

Model saved with loss: 32.06180763244629


  5%|▌         | 5/100 [01:03<20:01, 12.64s/it]

Model saved with loss: 25.752647399902344


  6%|▌         | 6/100 [01:15<19:48, 12.64s/it]

Model saved with loss: 23.92900848388672


  7%|▋         | 7/100 [01:28<19:40, 12.70s/it]

Model saved with loss: 23.219496726989746


  8%|▊         | 8/100 [01:41<19:26, 12.68s/it]

Model saved with loss: 19.127812385559082


  9%|▉         | 9/100 [01:53<19:11, 12.66s/it]

Model saved with loss: 16.084310054779053


 11%|█         | 11/100 [02:17<18:02, 12.16s/it]

Model saved with loss: 15.448028087615967


 12%|█▏        | 12/100 [02:30<18:03, 12.31s/it]

Model saved with loss: 14.445944786071777


 13%|█▎        | 13/100 [02:42<17:59, 12.41s/it]

Model saved with loss: 11.880154132843018


 14%|█▍        | 14/100 [02:55<17:53, 12.48s/it]

Model saved with loss: 11.83185863494873


 15%|█▌        | 15/100 [03:08<17:45, 12.53s/it]

Model saved with loss: 10.697999954223633


 16%|█▌        | 16/100 [03:20<17:34, 12.55s/it]

Model saved with loss: 9.999814510345459


 17%|█▋        | 17/100 [03:33<17:24, 12.58s/it]

Model saved with loss: 9.513706684112549


 18%|█▊        | 18/100 [03:46<17:12, 12.59s/it]

Model saved with loss: 8.96387767791748


 20%|██        | 20/100 [04:09<16:11, 12.15s/it]

Model saved with loss: 7.708811044692993


 21%|██        | 21/100 [04:22<16:12, 12.31s/it]

Model saved with loss: 7.457350015640259


 22%|██▏       | 22/100 [04:35<16:07, 12.41s/it]

Model saved with loss: 6.189404487609863


 24%|██▍       | 24/100 [04:58<15:16, 12.06s/it]

Model saved with loss: 5.70184063911438


 26%|██▌       | 26/100 [05:22<14:40, 11.89s/it]

Model saved with loss: 5.026899814605713


 27%|██▋       | 27/100 [05:35<14:45, 12.13s/it]

Model saved with loss: 4.467642545700073


 29%|██▉       | 29/100 [05:58<14:05, 11.92s/it]

Model saved with loss: 4.210372686386108


 31%|███       | 31/100 [06:22<13:34, 11.81s/it]

Model saved with loss: 3.5648547410964966


 33%|███▎      | 33/100 [06:46<13:08, 11.77s/it]

Model saved with loss: 3.2302368879318237


 35%|███▌      | 35/100 [07:10<12:43, 11.75s/it]

Model saved with loss: 2.9015170335769653


 37%|███▋      | 37/100 [07:33<12:19, 11.74s/it]

Model saved with loss: 2.5104525089263916


 38%|███▊      | 38/100 [07:46<12:24, 12.01s/it]

Model saved with loss: 2.403701663017273


 39%|███▉      | 39/100 [07:59<12:24, 12.20s/it]

Model saved with loss: 2.2716176509857178


 40%|████      | 40/100 [08:11<12:20, 12.35s/it]

Model saved with loss: 2.181091547012329


 41%|████      | 41/100 [08:24<12:14, 12.46s/it]

Model saved with loss: 2.1636529564857483


 42%|████▏     | 42/100 [08:37<12:05, 12.51s/it]

Model saved with loss: 2.0718968510627747


 43%|████▎     | 43/100 [08:49<11:55, 12.56s/it]

Model saved with loss: 1.9330343008041382


 44%|████▍     | 44/100 [09:02<11:45, 12.60s/it]

Model saved with loss: 1.807508945465088


 46%|████▌     | 46/100 [09:26<10:56, 12.16s/it]

Model saved with loss: 1.6663991212844849


 47%|████▋     | 47/100 [09:38<10:51, 12.30s/it]

Model saved with loss: 1.608429491519928


 48%|████▊     | 48/100 [09:51<10:45, 12.41s/it]

Model saved with loss: 1.5840855836868286


 49%|████▉     | 49/100 [10:04<10:36, 12.48s/it]

Model saved with loss: 1.4651144742965698


 52%|█████▏    | 52/100 [10:39<09:25, 11.78s/it]

Model saved with loss: 1.4021382927894592


 53%|█████▎    | 53/100 [10:51<09:25, 12.04s/it]

Model saved with loss: 1.3329032063484192


 56%|█████▌    | 56/100 [11:26<08:31, 11.63s/it]

Model saved with loss: 1.2881409227848053


 62%|██████▏   | 62/100 [12:34<07:06, 11.21s/it]

Model saved with loss: 1.2736429870128632


 63%|██████▎   | 63/100 [12:47<07:11, 11.66s/it]

Model saved with loss: 1.2558721601963043


 77%|███████▋  | 77/100 [15:28<04:37, 12.06s/it]


KeyboardInterrupt: 