In [1]:
import numpy as np
import torch
from torch import nn
import random

from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

import os

from datetime import datetime

In [2]:
train_betas_path = '../data/processed_data/subj01/nsd_train_fmriavg_nsdgeneral_sub1.npy'
train_caps_embed_path = '../data/caps_embeds/train_caps_embeds_sub1.npy'
train_caps_embed_neg_path = '../data/caps_embeds/train_caps_embeds_negative_sub1.npy'

# train_betas = np.load(train_betas_path)
# train_caps_embed = np.load(train_caps_embed_path)
# train_caps_embed_neg = np.load(train_caps_embed_neg_path)

In [3]:
test_betas_path = '../data/processed_data/subj01/nsd_test_fmriavg_nsdgeneral_sub1.npy'
test_caps_embed_path = '../data/caps_embeds/test_caps_embeds_sub1.npy'
test_caps_embed_neg_path = '../data/caps_embeds/test_caps_embeds_negative_sub1.npy'

# test_betas = np.load(test_betas_path)
# test_caps_embed = np.load(test_caps_embed_path)
# test_caps_embed_neg = np.load(test_caps_embed_neg_path)

In [4]:
class CustomDataset(Dataset):
    def __init__(self, betas_path, embeds_path, neg_embeds_path):
        self.betas = torch.from_numpy(np.load(betas_path)).float().cuda()
        self.embeds = torch.from_numpy(
            np.squeeze(np.load(embeds_path))
            ).float().cuda()
        self.embeds_neg = torch.from_numpy(
            np.squeeze(np.load(neg_embeds_path))
            ).float().cuda()

    def __len__(self):
        return len(self.betas)

    def __getitem__(self, index):
        return self.betas[index], self.embeds[index], self.embeds_neg[index]


class CustomDataLoader:
    def __init__(self, dataset, batch_size):
        self.data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    def __iter__(self):
        return iter(self.data_loader)

    def __len__(self):
        return len(self.data_loader)

batch_size = 64

train_dataset = CustomDataset(train_betas_path,
                        train_caps_embed_path,
                        train_caps_embed_neg_path)
train_data_loader = CustomDataLoader(train_dataset, batch_size)

test_dataset = CustomDataset(test_betas_path,
                              test_caps_embed_path,
                              test_caps_embed_neg_path)
test_data_loader = CustomDataLoader(test_dataset, batch_size)

for beta, embed, neg_embed in train_data_loader:
    print(f"beta      {beta.shape}")
    print(f"embed     {embed.shape}")
    print(f"neg_embed {neg_embed.shape}")
    break

beta      torch.Size([64, 15724])
embed     torch.Size([64, 1280])
neg_embed torch.Size([64, 1280])


In [5]:
input_dim = 15724
print('input_dim:', input_dim)
ouput_dim = 1280
print('output_dim:', ouput_dim)

input_dim: 15724
output_dim: 1280


In [6]:
class LinearRegression(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.layer = nn.Linear(in_features=input_channels,
                               out_features=output_channels)
        
    def forward(self, x):
        return self.layer(x)

model_embed = LinearRegression(input_dim, ouput_dim).cuda()
model_neg = LinearRegression(input_dim, ouput_dim).cuda()

In [7]:
next(model_embed.parameters()).dtype

torch.float32

In [8]:
mse_loss = nn.MSELoss()
optim_embed = torch.optim.Adam(params=model_embed.parameters(), lr=0.0001)
optim_neg = torch.optim.Adam(params=model_neg.parameters(), lr=0.0001)

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
run_name = 'first' #+ datetime.now().strftime("%Y-%m-%d_%H:%M")
writer = SummaryWriter(os.path.join('runs', run_name))

In [10]:
num_epoches = 100
iter_index = 0
test_iter_index = 0
min_test_loss = 1e15
best_model_path = '../models/best_model.pth'
for epoch in tqdm(range(num_epoches)):
    for betas, embeds, neg_embeds in train_data_loader:
        iter_index += 1
        pred_embeds = model_embed(betas)
        pred_neg_embeds = model_neg(betas)

        loss_embed = mse_loss(pred_embeds, embeds)
        optim_embed.zero_grad()
        loss_embed.backward()
        optim_embed.step()
        
        loss_neg = mse_loss(pred_neg_embeds, neg_embeds)
        optim_neg.zero_grad()
        loss_neg.backward()
        optim_neg.step()

        writer.add_scalars('train/losses', {
            'embed': loss_embed.item(),
            'negat': loss_neg.item(),
        }, iter_index)

    test_loss = 0
    for betas, embeds, neg_embeds in test_data_loader:
        test_iter_index += 1
        pred_embeds = model_embed(betas)
        pred_neg_embeds = model_neg(betas)

        loss_embed = mse_loss(pred_embeds, embeds)
        
        loss_neg = mse_loss(pred_neg_embeds, neg_embeds)

        writer.add_scalars('test/losses', {
            'embed': loss_embed.item(),
            'negat': loss_neg.item(),
        }, test_iter_index)
        test_loss += loss_embed.item() + loss_neg.item()

    print(test_loss)
    if test_loss < min_test_loss:
        min_test_loss = test_loss
        torch.save({
                'model_embed_state_dict': model_embed.state_dict(),
                'model_neg_state_dict': model_neg.state_dict(),
            }, best_model_path)

  0%|          | 0/100 [00:00<?, ?it/s]

582457.1640625


  1%|          | 1/100 [00:03<05:39,  3.43s/it]

469532.1123046875


  2%|▏         | 2/100 [00:06<05:32,  3.39s/it]

400187.9951171875


  3%|▎         | 3/100 [00:10<05:26,  3.37s/it]

360722.3310546875


  4%|▍         | 4/100 [00:13<05:22,  3.36s/it]

339101.12890625


  5%|▌         | 5/100 [00:16<05:19,  3.36s/it]

318082.435546875


  6%|▌         | 6/100 [00:20<05:15,  3.35s/it]

292881.2958984375


  7%|▋         | 7/100 [00:23<05:12,  3.36s/it]

281820.24609375


  8%|▊         | 8/100 [00:26<05:08,  3.36s/it]

274581.10107421875


  9%|▉         | 9/100 [00:30<05:07,  3.38s/it]

258297.6591796875


 10%|█         | 10/100 [00:33<05:01,  3.35s/it]

251225.12158203125


 11%|█         | 11/100 [00:36<04:57,  3.35s/it]

242389.40380859375


 12%|█▏        | 12/100 [00:40<04:54,  3.34s/it]

235863.9267578125


 14%|█▍        | 14/100 [00:46<04:36,  3.22s/it]

236093.1513671875
222917.7802734375


 15%|█▌        | 15/100 [00:49<04:37,  3.26s/it]

216903.5107421875


 16%|█▌        | 16/100 [00:53<04:36,  3.29s/it]

211149.5458984375


 18%|█▊        | 18/100 [00:59<04:22,  3.20s/it]

212412.31591796875
206852.802734375


 19%|█▉        | 19/100 [01:02<04:22,  3.24s/it]

206025.99951171875


 21%|██        | 21/100 [01:09<04:10,  3.18s/it]

214010.92578125
200739.50537109375


 22%|██▏       | 22/100 [01:12<04:11,  3.22s/it]

190703.806640625


 24%|██▍       | 24/100 [01:18<04:01,  3.17s/it]

196803.0869140625


 25%|██▌       | 25/100 [01:21<03:52,  3.10s/it]

194843.70751953125
185470.20263671875


 27%|██▋       | 27/100 [01:28<03:47,  3.12s/it]

186010.2490234375


 28%|██▊       | 28/100 [01:31<03:40,  3.07s/it]

187725.6494140625
183808.99755859375


 29%|██▉       | 29/100 [01:34<03:44,  3.16s/it]

180436.24853515625


 31%|███       | 31/100 [01:40<03:36,  3.13s/it]

197662.41552734375
179381.0439453125


 32%|███▏      | 32/100 [01:44<03:37,  3.20s/it]

178799.9951171875


 33%|███▎      | 33/100 [01:47<03:37,  3.25s/it]

173980.24658203125


 35%|███▌      | 35/100 [01:53<03:27,  3.19s/it]

179154.3896484375
170629.0771484375


 36%|███▌      | 36/100 [01:57<03:27,  3.25s/it]

167573.59326171875


 38%|███▊      | 38/100 [02:03<03:17,  3.18s/it]

173166.8095703125


 39%|███▉      | 39/100 [02:06<03:09,  3.11s/it]

183672.61376953125


 40%|████      | 40/100 [02:09<03:03,  3.06s/it]

180437.34619140625


 41%|████      | 41/100 [02:12<02:58,  3.02s/it]

199381.76953125
155591.96044921875


 43%|████▎     | 43/100 [02:18<02:54,  3.07s/it]

169411.28271484375


 44%|████▍     | 44/100 [02:21<02:49,  3.02s/it]

169087.4921875


 45%|████▌     | 45/100 [02:24<02:44,  3.00s/it]

166136.69921875


 46%|████▌     | 46/100 [02:27<02:41,  2.99s/it]

159064.6142578125


 47%|████▋     | 47/100 [02:30<02:37,  2.97s/it]

160723.58984375


 48%|████▊     | 48/100 [02:33<02:34,  2.96s/it]

155816.484375


 49%|████▉     | 49/100 [02:36<02:30,  2.95s/it]

161526.67138671875


 50%|█████     | 50/100 [02:39<02:27,  2.96s/it]

160069.923828125
155133.28125


 52%|█████▏    | 52/100 [02:45<02:27,  3.06s/it]

160565.65185546875


 53%|█████▎    | 53/100 [02:48<02:22,  3.04s/it]

158905.8974609375
152360.93408203125


 54%|█████▍    | 54/100 [02:51<02:24,  3.14s/it]

150346.02880859375


 56%|█████▌    | 56/100 [02:58<02:17,  3.13s/it]

159070.8466796875


 57%|█████▋    | 57/100 [03:01<02:12,  3.07s/it]

156279.83642578125


 58%|█████▊    | 58/100 [03:04<02:07,  3.03s/it]

163300.98681640625


 59%|█████▉    | 59/100 [03:07<02:03,  3.00s/it]

155483.46923828125
148507.29736328125


 60%|██████    | 60/100 [03:10<02:03,  3.10s/it]

146823.28125


 62%|██████▏   | 62/100 [03:16<01:57,  3.10s/it]

153592.02783203125


 63%|██████▎   | 63/100 [03:19<01:52,  3.05s/it]

212659.6650390625


 64%|██████▍   | 64/100 [03:22<01:48,  3.02s/it]

148902.62353515625


 65%|██████▌   | 65/100 [03:25<01:44,  2.99s/it]

155387.91015625


 66%|██████▌   | 66/100 [03:28<01:41,  2.98s/it]

147270.875


 67%|██████▋   | 67/100 [03:31<01:37,  2.96s/it]

147463.0361328125
144571.6953125


 69%|██████▉   | 69/100 [03:37<01:34,  3.04s/it]

152548.11083984375


 70%|███████   | 70/100 [03:40<01:30,  3.01s/it]

158142.04052734375
139981.48486328125


 72%|███████▏  | 72/100 [03:46<01:25,  3.06s/it]

148228.37060546875


 73%|███████▎  | 73/100 [03:49<01:21,  3.02s/it]

165025.751953125


 74%|███████▍  | 74/100 [03:52<01:17,  2.99s/it]

144091.9541015625


 75%|███████▌  | 75/100 [03:55<01:14,  2.97s/it]

167456.068359375


 76%|███████▌  | 76/100 [03:58<01:11,  2.96s/it]

151405.990234375


 77%|███████▋  | 77/100 [04:01<01:07,  2.95s/it]

146872.083984375


 78%|███████▊  | 78/100 [04:04<01:04,  2.95s/it]

145855.6611328125


 79%|███████▉  | 79/100 [04:07<01:01,  2.95s/it]

147031.45263671875


 80%|████████  | 80/100 [04:10<00:58,  2.94s/it]

143509.00830078125


 81%|████████  | 81/100 [04:13<00:55,  2.94s/it]

176437.5537109375


 82%|████████▏ | 82/100 [04:16<00:52,  2.94s/it]

144073.63427734375


 83%|████████▎ | 83/100 [04:19<00:49,  2.94s/it]

144152.53466796875


 84%|████████▍ | 84/100 [04:22<00:46,  2.94s/it]

144153.65185546875
137484.76611328125


 86%|████████▌ | 86/100 [04:28<00:42,  3.03s/it]

142843.50146484375


 87%|████████▋ | 87/100 [04:31<00:39,  3.00s/it]

145781.54931640625


 88%|████████▊ | 88/100 [04:34<00:35,  2.98s/it]

149998.33447265625


 89%|████████▉ | 89/100 [04:37<00:32,  2.97s/it]

145496.78271484375


 90%|█████████ | 90/100 [04:40<00:29,  2.96s/it]

138511.50390625
137338.4736328125


 92%|█████████▏| 92/100 [04:46<00:24,  3.04s/it]

149093.49658203125


 93%|█████████▎| 93/100 [04:49<00:21,  3.01s/it]

138579.5673828125


 94%|█████████▍| 94/100 [04:52<00:17,  2.98s/it]

145244.4736328125


 95%|█████████▌| 95/100 [04:55<00:14,  2.97s/it]

142653.41455078125


 96%|█████████▌| 96/100 [04:58<00:11,  2.96s/it]

150648.8720703125


 97%|█████████▋| 97/100 [05:01<00:08,  2.96s/it]

153273.9287109375
136520.29272460938


 99%|█████████▉| 99/100 [05:07<00:03,  3.03s/it]

164515.69091796875


100%|██████████| 100/100 [05:10<00:00,  3.10s/it]

157114.84228515625



