In [None]:
from google.colab import drive
# Mount Google Drive (Colab only) so the files under /content/drive/MyDrive are reachable.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import numpy as np
import torch
from torch import nn
import random

from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

import os

from datetime import datetime

In [None]:
# Root folder on the mounted Drive; the dataset files, the TensorBoard run
# directory and the checkpoint location are all derived from this path.
data_path = '/content/drive/MyDrive/data/'

In [None]:
# Files of the training split (names indicate averaged fMRI betas, caption
# embeddings and negative caption embeddings for subject 1).
# os.path.join is used, as elsewhere in this notebook, so the paths stay valid
# whether or not data_path ends with a separator.
train_betas_path = os.path.join(data_path, 'nsd_train_fmriavg_nsdgeneral_sub1.npy')
train_caps_embed_path = os.path.join(data_path, 'train_caps_embeds_sub1.npy')
train_caps_embed_neg_path = os.path.join(data_path, 'train_caps_embeds_negative_sub1.npy')

# Uncomment to load the raw arrays for interactive inspection.
# train_betas = np.load(train_betas_path)
# train_caps_embed = np.load(train_caps_embed_path)
# train_caps_embed_neg = np.load(train_caps_embed_neg_path)

In [None]:
# Files of the test split (betas, caption embeddings, negative caption embeddings).
# os.path.join is used, as elsewhere in this notebook, so the paths stay valid
# whether or not data_path ends with a separator.
test_betas_path = os.path.join(data_path, 'test_betas_split_sub1.npy')
test_caps_embed_path = os.path.join(data_path, 'test_caps_embeds_split_sub1.npy')
test_caps_embed_neg_path = os.path.join(data_path, 'test_caps_embeds_negative_split_sub1.npy')

# Uncomment to load the raw arrays for interactive inspection.
# test_betas = np.load(test_betas_path)
# test_caps_embed = np.load(test_caps_embed_path)
# test_caps_embed_neg = np.load(test_caps_embed_neg_path)

In [None]:
# Files of the validation split (betas, caption embeddings, negative caption embeddings).
# os.path.join is used, as elsewhere in this notebook, so the paths stay valid
# whether or not data_path ends with a separator.
val_betas_path = os.path.join(data_path, 'val_betas_split_sub1.npy')
val_caps_embed_path = os.path.join(data_path, 'val_caps_embeds_split_sub1.npy')
# NOTE(review): this file name says "split_negative" while the test split uses
# "negative_split" — confirm it matches the file that actually exists on Drive.
val_caps_embed_neg_path = os.path.join(data_path, 'val_caps_embeds_split_negative_sub1.npy')

# Uncomment to load the raw arrays for interactive inspection.
# val_betas = np.load(val_betas_path)
# val_caps_embed = np.load(val_caps_embed_path)
# val_caps_embed_neg = np.load(val_caps_embed_neg_path)

In [None]:
class CustomDataset(Dataset):
    """In-memory dataset pairing each betas row with its embedding row.

    Both ``.npy`` files are loaded completely, converted to float32 tensors and
    moved to ``device`` once, so ``__getitem__`` is a plain tensor index.

    Args:
        betas_path: Path to a ``.npy`` array whose first axis indexes samples.
        embeds_path: Path to a ``.npy`` array with the same number of samples.
            Singleton axes after the first one are removed.
        device: Device the tensors are stored on. ``None`` (default) selects
            CUDA when it is available and falls back to CPU otherwise.

    Raises:
        ValueError: If the two arrays do not hold the same number of samples.
    """

    def __init__(self, betas_path, embeds_path, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        betas = np.load(betas_path)
        embeds = np.load(embeds_path)

        # Remove singleton axes but never the sample axis: an unqualified
        # np.squeeze would also drop axis 0 when the file holds one sample.
        singleton_axes = tuple(
            axis for axis in range(1, embeds.ndim) if embeds.shape[axis] == 1
        )
        if singleton_axes:
            embeds = np.squeeze(embeds, axis=singleton_axes)

        # Rows are paired purely by position, so a length mismatch means the
        # pairing is wrong; fail here rather than partway through an epoch.
        if len(betas) != len(embeds):
            raise ValueError(
                f'betas and embeds must have the same number of samples, '
                f'got {len(betas)} and {len(embeds)}'
            )

        self.betas = torch.from_numpy(betas).float().to(device)
        self.embeds = torch.from_numpy(embeds).float().to(device)

    def __len__(self):
        """Return the number of samples."""
        return len(self.betas)

    def __getitem__(self, index):
        """Return the ``(betas, embedding)`` pair stored at ``index``."""
        return self.betas[index], self.embeds[index]


class CustomDataLoader:
    """Thin iterable wrapper around ``torch.utils.data.DataLoader``.

    Args:
        dataset: Dataset (or any indexable with ``__len__``) to draw batches from.
        batch_size: Number of samples per batch; the last batch may be smaller.
        shuffle: Whether to reshuffle the data on every pass. Defaults to
            ``True``, which was previously hard-coded. Pass ``False`` for
            evaluation loaders to get a deterministic batch order.
    """

    def __init__(self, dataset, batch_size, shuffle=True):
        self.data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    def __iter__(self):
        """Start a fresh pass over the wrapped loader."""
        return iter(self.data_loader)

    def __len__(self):
        """Return the number of batches per pass."""
        return len(self.data_loader)

# Batch size shared by the train, test and validation loaders.
batch_size = 64

# Each dataset loads its two .npy files fully into memory when constructed.
train_dataset = CustomDataset(train_betas_path,
                              train_caps_embed_path,)
train_data_loader = CustomDataLoader(train_dataset, batch_size)

# The test and validation loaders are built the same way, so they are shuffled too.
test_dataset = CustomDataset(test_betas_path,
                             test_caps_embed_path,)
test_data_loader = CustomDataLoader(test_dataset, batch_size)

val_dataset = CustomDataset(val_betas_path,
                            val_caps_embed_path,)
val_data_loader = CustomDataLoader(val_dataset, batch_size)

# Sanity check: print the tensor shapes of the first training batch only.
for beta, embed in train_data_loader:
    print(f"beta      {beta.shape}")
    print(f"embed     {embed.shape}")
    break

beta      torch.Size([64, 15724])
embed     torch.Size([64, 1280])


In [None]:
# Read the layer sizes from the loaded training tensors instead of hard-coding
# them, so the model keeps matching the data if different files are used.
input_dim = train_dataset.betas.shape[1]
print('input_dim:', input_dim)
output_dim = train_dataset.embeds.shape[1]
# The misspelled name is kept as an alias because the model-construction cell uses it.
ouput_dim = output_dim
print('output_dim:', ouput_dim)

input_dim: 15724
output_dim: 1280


In [None]:
class RegressionModel(nn.Module):
    """Fully connected regressor mapping an input vector to an output vector.

    The network is a stack of four hidden ``Linear`` layers (12288, 10240, 5120
    and 2048 units), each followed by ``LeakyReLU(0.05)``, and a final ``Linear``
    projection to ``output_channels``.

    Args:
        input_channels: Size of the last dimension of the input.
        output_channels: Size of the last dimension of the prediction.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        # NOTE: this layer is registered (so it appears in parameters() and in
        # state_dict) but forward() never calls it.
        self.linear_regression = nn.Linear(input_channels, output_channels)

        # Hidden widths, expressed as multiples of 1024.
        unit = 2**10
        self.layer0 = unit * 12
        self.layer1 = unit * 10
        self.layer2 = unit * 5
        self.layer3 = unit * 2

        # Build Linear -> LeakyReLU pairs for consecutive widths, then the head.
        widths = (input_channels, self.layer0, self.layer1, self.layer2, self.layer3)
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(in_features=fan_in, out_features=fan_out))
            stages.append(nn.LeakyReLU(0.05, inplace=True))
        stages.append(nn.Linear(in_features=self.layer3, out_features=output_channels))

        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP stack and return the prediction."""
        return self.mlp(x)

# Build the regressor with the dimensions defined above and move its parameters to the GPU.
model_embed = RegressionModel(input_dim, ouput_dim).cuda()

In [None]:
import torch.nn.functional as F

class CosineSimilarityLoss(nn.Module):
    """Cosine-distance loss: one minus the batch-mean cosine similarity."""

    def __init__(self):
        super().__init__()

    def forward(self, y_true, y_pred):
        """Return ``1 - mean(cos(y_true, y_pred))``.

        The similarity is computed along ``dim=1`` and then averaged over all
        remaining entries, so the result is a scalar tensor.
        """
        per_sample_similarity = F.cosine_similarity(y_true, y_pred, dim=1)
        return 1 - per_sample_similarity.mean()

In [None]:
# Loss functions and optimizer for the embedding regressor.
# In the cells visible here, training minimises mse_loss; cos_loss
# (1 - mean cosine similarity) is only logged during validation and testing.
cos_loss = CosineSimilarityLoss()
mse_loss = nn.MSELoss(reduction='mean')
# NOTE(review): torch.optim.Adam applies weight_decay as an L2 term added to
# the gradient; torch.optim.AdamW gives decoupled decay — confirm which is intended.
optim_embed = torch.optim.Adam(
    params=model_embed.parameters(), lr=0.0001, weight_decay=0.01)

In [None]:
# TensorBoard events are written to <data_path>/runs/<run_name>. With the
# timestamp suffix commented out the name is constant, so every execution of
# the notebook writes into the same run directory.
run_name = 'mlp_one_model' #+ datetime.now().strftime("%Y-%m-%d_%H:%M")
writer = SummaryWriter(os.path.join(data_path, 'runs', run_name))

In [None]:
import torch
from tqdm import tqdm

num_epochs = 100
iter_index = 0  # global training step, used as the TensorBoard x-axis
# Best validation loss seen so far. The "test" wording is historical: the value
# is computed on val_data_loader, and the name is also the checkpoint key.
min_test_loss = 1e15
best_model_path = os.path.join(data_path, f'../models/{run_name}.pth')
# torch.save does not create missing folders; create the target directory up
# front instead of failing after the first full epoch of training.
os.makedirs(os.path.normpath(os.path.dirname(best_model_path)), exist_ok=True)
test_iter_index = 0  # global validation step, used as the TensorBoard x-axis

for epoch in tqdm(range(num_epochs)):
    # ---- training pass: one optimizer step per batch, MSE objective ----
    model_embed.train()
    for betas, embeds in train_data_loader:
        iter_index += 1
        pred_embeds = model_embed(betas)

        loss_embed = mse_loss(pred_embeds, embeds)
        optim_embed.zero_grad()
        loss_embed.backward()
        optim_embed.step()

        writer.add_scalars('train/losses', {
            'embed': loss_embed.item(),
        }, iter_index)

    # ---- validation pass ----
    test_loss = 0
    n_batches = 0
    n_samples = 0
    model_embed.eval()
    with torch.inference_mode():
        for betas, embeds in val_data_loader:
            test_iter_index += 1
            pred_embeds = model_embed(betas)

            loss_embed = mse_loss(pred_embeds, embeds)

            # cos_loss returns 1 - cosine similarity, so lower is better even
            # though the TensorBoard tag below is called "cos_sim".
            cos_embed = cos_loss(pred_embeds, embeds)

            writer.add_scalars('val/losses', {
                'embed': loss_embed.item(),
            }, test_iter_index)

            writer.add_scalars('val/cos_sim', {
                'cos_embed': cos_embed.item(),
            }, test_iter_index)

            # Weight each batch mean by its size. A plain mean of batch means
            # over-weights the smaller final batch, and because the loader
            # shuffles, that made the metric vary between passes for identical
            # weights.
            batch_len = len(betas)
            n_batches += 1
            n_samples += batch_len
            test_loss += loss_embed.item() * batch_len

    if n_samples == 0:
        raise RuntimeError('val_data_loader yielded no samples; cannot compute a validation loss')
    test_loss /= n_samples
    print(test_loss)
    # Keep only the checkpoint with the lowest validation loss.
    if test_loss < min_test_loss:
        min_test_loss = test_loss
        print(f'Model saved with loss: {test_loss}')
        torch.save({
            'model_embed_state_dict': model_embed.state_dict(),
            'epoch': epoch,
            'min_test_loss': min_test_loss,
        }, best_model_path)

    # Push buffered events to disk once per epoch so the curves survive an
    # interrupted session.
    writer.flush()

  0%|          | 0/100 [00:00<?, ?it/s]

0.8785734251141548
Model saved with loss: 0.8785734251141548


  1%|          | 1/100 [00:30<50:34, 30.65s/it]

0.876572422683239
Model saved with loss: 0.876572422683239


  2%|▏         | 2/100 [01:05<54:11, 33.18s/it]

0.8675332069396973
Model saved with loss: 0.8675332069396973


  3%|▎         | 3/100 [01:45<58:23, 36.11s/it]

0.8550610914826393
Model saved with loss: 0.8550610914826393


  4%|▍         | 4/100 [02:21<57:53, 36.18s/it]

0.844649150967598
Model saved with loss: 0.844649150967598


  5%|▌         | 5/100 [02:57<57:04, 36.04s/it]

0.8373319283127785
Model saved with loss: 0.8373319283127785


  6%|▌         | 6/100 [03:37<58:48, 37.54s/it]

0.8264186605811119
Model saved with loss: 0.8264186605811119


  7%|▋         | 7/100 [04:14<57:57, 37.40s/it]

0.818101592361927
Model saved with loss: 0.818101592361927


  8%|▊         | 8/100 [04:49<56:12, 36.66s/it]

0.8096858486533165
Model saved with loss: 0.8096858486533165


  9%|▉         | 9/100 [05:30<57:14, 37.74s/it]

0.7997043430805206
Model saved with loss: 0.7997043430805206


 10%|█         | 10/100 [06:07<56:16, 37.52s/it]

0.7914549931883812
Model saved with loss: 0.7914549931883812


 11%|█         | 11/100 [06:43<55:19, 37.29s/it]

0.7891509383916855
Model saved with loss: 0.7891509383916855


 12%|█▏        | 12/100 [08:20<1:21:01, 55.25s/it]

0.7813941314816475
Model saved with loss: 0.7813941314816475


 13%|█▎        | 13/100 [08:57<1:12:06, 49.73s/it]

0.7719646021723747
Model saved with loss: 0.7719646021723747


 14%|█▍        | 14/100 [09:33<1:05:31, 45.72s/it]

0.7704311236739159
Model saved with loss: 0.7704311236739159


 15%|█▌        | 15/100 [10:14<1:02:43, 44.28s/it]

0.7674928158521652
Model saved with loss: 0.7674928158521652


 17%|█▋        | 17/100 [11:23<52:56, 38.28s/it]  

0.7720538899302483
0.7553433999419212
Model saved with loss: 0.7553433999419212


 19%|█▉        | 19/100 [12:18<43:54, 32.53s/it]

0.757995568215847
0.7498874366283417
Model saved with loss: 0.7498874366283417


 21%|██        | 21/100 [13:55<50:36, 38.44s/it]

0.7509173527359962
0.749005064368248
Model saved with loss: 0.749005064368248


 22%|██▏       | 22/100 [14:29<48:07, 37.01s/it]

0.7459498271346092
Model saved with loss: 0.7459498271346092


 23%|██▎       | 23/100 [15:07<47:49, 37.27s/it]

0.7391054108738899
Model saved with loss: 0.7391054108738899


 25%|██▌       | 25/100 [16:35<48:34, 38.86s/it]

0.7418396174907684
0.7330734878778458
Model saved with loss: 0.7330734878778458


 27%|██▋       | 27/100 [17:37<41:43, 34.29s/it]

0.7369327545166016


 28%|██▊       | 28/100 [18:01<37:21, 31.13s/it]

0.7409880384802818
0.7227569371461868
Model saved with loss: 0.7227569371461868


 30%|███       | 30/100 [19:11<37:27, 32.11s/it]

0.7371014505624771


 31%|███       | 31/100 [19:35<34:03, 29.61s/it]

0.7230398431420326
0.7161630317568779
Model saved with loss: 0.7161630317568779


 33%|███▎      | 33/100 [21:12<41:23, 37.07s/it]

0.7173234894871712
0.715434268116951
Model saved with loss: 0.715434268116951


 34%|███▍      | 34/100 [22:23<51:49, 47.11s/it]

0.7149079069495201
Model saved with loss: 0.7149079069495201


 35%|███▌      | 35/100 [22:58<47:01, 43.40s/it]

0.7056737467646599
Model saved with loss: 0.7056737467646599


 36%|███▌      | 36/100 [24:10<55:25, 51.95s/it]

0.6992059275507927
Model saved with loss: 0.6992059275507927


 38%|███▊      | 38/100 [25:11<41:50, 40.49s/it]

0.705453410744667
0.6928930953145027
Model saved with loss: 0.6928930953145027


 40%|████      | 40/100 [26:17<35:47, 35.78s/it]

0.7108314633369446


 41%|████      | 41/100 [26:41<31:39, 32.19s/it]

0.7008620873093605
0.6857885792851448
Model saved with loss: 0.6857885792851448


 43%|████▎     | 43/100 [27:41<29:05, 30.62s/it]

0.70700853317976


 44%|████▍     | 44/100 [28:05<26:40, 28.57s/it]

0.6996202766895294


 45%|████▌     | 45/100 [28:29<24:50, 27.10s/it]

0.6954951882362366


 46%|████▌     | 46/100 [28:52<23:26, 26.05s/it]

0.7056652903556824
0.6810555681586266
Model saved with loss: 0.6810555681586266


 48%|████▊     | 48/100 [29:53<23:52, 27.55s/it]

0.68768011033535


 49%|████▉     | 49/100 [30:16<22:27, 26.41s/it]

0.6869299113750458


 50%|█████     | 50/100 [30:40<21:19, 25.58s/it]

0.6919002532958984


 51%|█████     | 51/100 [31:04<20:25, 25.00s/it]

0.6943465545773506


 52%|█████▏    | 52/100 [31:27<19:40, 24.59s/it]

0.6835731193423271


 53%|█████▎    | 53/100 [31:51<19:02, 24.32s/it]

0.6868076547980309


 54%|█████▍    | 54/100 [32:15<18:29, 24.12s/it]

0.6974243894219398


 55%|█████▌    | 55/100 [32:38<17:58, 23.98s/it]

0.7198136895895004


 56%|█████▌    | 56/100 [33:02<17:30, 23.88s/it]

0.712311677634716


 57%|█████▋    | 57/100 [33:25<17:03, 23.81s/it]

0.690217636525631
0.6724880710244179
Model saved with loss: 0.6724880710244179


 59%|█████▉    | 59/100 [34:28<18:22, 26.89s/it]

0.6788283437490463


 60%|██████    | 60/100 [34:52<17:18, 25.95s/it]

0.6851162239909172
0.6706938669085503
Model saved with loss: 0.6706938669085503


 62%|██████▏   | 62/100 [35:53<17:34, 27.75s/it]

0.6796326190233231


 63%|██████▎   | 63/100 [36:17<16:23, 26.57s/it]

0.6855711415410042


 64%|██████▍   | 64/100 [36:41<15:25, 25.70s/it]

0.677443690598011


 65%|██████▌   | 65/100 [37:04<14:37, 25.07s/it]

0.6773523688316345


 66%|██████▌   | 66/100 [37:28<13:58, 24.65s/it]

0.6795916333794594


 67%|██████▋   | 67/100 [37:51<13:23, 24.36s/it]

0.6877826526761055


 68%|██████▊   | 68/100 [38:15<12:52, 24.14s/it]

0.6750834509730339


 69%|██████▉   | 69/100 [38:39<12:23, 24.00s/it]

0.6873651370406151


 70%|███████   | 70/100 [39:02<11:56, 23.89s/it]

0.676732063293457


 71%|███████   | 71/100 [39:26<11:30, 23.82s/it]

0.6893286034464836


 72%|███████▏  | 72/100 [39:50<11:05, 23.77s/it]

0.6958615034818649


 73%|███████▎  | 73/100 [40:13<10:40, 23.73s/it]

0.67887382209301


 74%|███████▍  | 74/100 [40:37<10:16, 23.70s/it]

0.6748924478888512


 75%|███████▌  | 75/100 [41:01<09:52, 23.69s/it]

0.688378319144249


 76%|███████▌  | 76/100 [41:24<09:28, 23.68s/it]

0.6794803217053413


 77%|███████▋  | 77/100 [41:48<09:04, 23.68s/it]

0.6837047412991524


 78%|███████▊  | 78/100 [42:12<08:40, 23.68s/it]

0.6854564622044563


 79%|███████▉  | 79/100 [42:35<08:16, 23.66s/it]

0.6810959056019783


 80%|████████  | 80/100 [42:59<07:53, 23.65s/it]

0.6790306270122528
0.6686645597219467
Model saved with loss: 0.6686645597219467


 82%|████████▏ | 82/100 [44:00<07:56, 26.46s/it]

0.6754642724990845


 83%|████████▎ | 83/100 [44:23<07:16, 25.65s/it]

0.669106587767601


 84%|████████▍ | 84/100 [44:47<06:40, 25.05s/it]

0.6718477234244347


 85%|████████▌ | 85/100 [45:11<06:09, 24.61s/it]

0.6808960288763046
0.6680191978812218
Model saved with loss: 0.6680191978812218


 87%|████████▋ | 87/100 [46:07<05:38, 26.08s/it]

0.6976626291871071


 88%|████████▊ | 88/100 [46:31<05:04, 25.37s/it]

0.8028391003608704


 89%|████████▉ | 89/100 [46:55<04:33, 24.85s/it]

0.6813753470778465
0.6651092022657394
Model saved with loss: 0.6651092022657394


 91%|█████████ | 91/100 [47:50<03:54, 26.01s/it]

0.668582871556282


 92%|█████████▏| 92/100 [48:14<03:22, 25.32s/it]

0.6740228980779648


 93%|█████████▎| 93/100 [48:38<02:53, 24.82s/it]

0.6878287270665169


 94%|█████████▍| 94/100 [49:01<02:26, 24.47s/it]

0.6821145862340927


 95%|█████████▌| 95/100 [49:25<02:01, 24.23s/it]

0.6922249346971512


 96%|█████████▌| 96/100 [49:49<01:36, 24.05s/it]

0.6678513288497925


 97%|█████████▋| 97/100 [50:12<01:11, 23.93s/it]

0.6840250715613365


 98%|█████████▊| 98/100 [50:36<00:47, 23.84s/it]

0.6745445132255554


 99%|█████████▉| 99/100 [50:59<00:23, 23.78s/it]

0.6929002478718758


100%|██████████| 100/100 [51:23<00:00, 30.84s/it]

0.6764207854866982





In [None]:
# Evaluate on the test split with the weights currently held by model_embed.
# NOTE(review): these are the last-epoch weights, not the best checkpoint
# stored at best_model_path — reload it first if the best model should be scored.
model_embed.eval()
# Start the accumulator from zero. Without this reset test_loss still holds the
# final validation average from the training cell, which would be folded into
# the test total.
test_loss = 0
with torch.inference_mode():
    n_batches = 0
    for betas, embeds in test_data_loader:
        # The step counter continues from the validation loop, so the test
        # points are logged after the validation steps on the TensorBoard x-axis.
        test_iter_index += 1
        pred_embeds = model_embed(betas)

        loss_embed = mse_loss(pred_embeds, embeds)

        # cos_loss returns 1 - cosine similarity (lower is better).
        cos_embed = cos_loss(pred_embeds, embeds)

        writer.add_scalars('test/losses', {
            'embed': loss_embed.item(),
        }, test_iter_index)

        writer.add_scalars('test/cos_sim', {
            'cos_embed': cos_embed.item(),
        }, test_iter_index)

        # test_loss is a running sum of batch means; divide by n_batches
        # afterwards to obtain the average test loss.
        n_batches += 1
        test_loss += loss_embed.item()