In [1]:
import os

from glob import glob
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision

from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from tqdm.notebook import tqdm

from utils import *

# Prefer the GPU, but fall back to the CPU so that placing models on `device`
# does not fail on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Single seed shared by the stochastic steps of this notebook.
_RANDOM_STATE = 42

# NOTE(review): `seed_everything` is not imported explicitly; presumably it is
# provided by the `from utils import *` star import — confirm which RNGs it seeds.
seed_everything(_RANDOM_STATE)

In [2]:
# Directory holding the stimulus videos (and their extracted audio tracks).
_VIDEO_DIR = '../../Algonauts2021_devkit/AlgonautsVideos268_All_30fpsmax'

video_paths = sorted(glob(f'{_VIDEO_DIR}/*.mp4'))
audio_paths = sorted(glob(f'{_VIDEO_DIR}/*.wav'))

# glob() always returns a list (possibly empty), never None, so the audio
# extraction has to be triggered by an emptiness check.
if not audio_paths:
    for video_path in video_paths:
        extract_audios(video_path)

    # Re-scan so the count printed below reflects the freshly extracted files.
    # NOTE(review): assumes extract_audios writes .wav files next to the videos — confirm.
    audio_paths = sorted(glob(f'{_VIDEO_DIR}/*.wav'))

print(f'Found {len(video_paths)} videos, {len(audio_paths)} audios')

Found 1102 videos, 1068 audios


In [30]:
def _extract_densenet_embeddings(model, video_paths, mode='mean'):
    """Extract one aggregated embedding per video from a frame-level image model.

    Every frame yielded by ``sample_video_from_mp4`` is resized to 224x224,
    normalized with the ImageNet channel statistics, passed through ``model``
    as a batch of one, and flattened. The per-frame vectors of a video are then
    pooled over the frame axis.

    Per-block feature extraction is obtained by passing a truncated network as
    ``model``; this function itself is agnostic to where the network was cut.

    Args:
        model: Callable network applied to a ``(1, 3, 224, 224)`` tensor. Inputs
            are moved to the device of the model's first parameter; models
            without parameters fall back to CUDA, as before.
        video_paths: Sized iterable of video file paths.
        mode: ``'mean'`` averages the frame embeddings; ``'statpool'``
            concatenates their mean and standard deviation.

    Returns:
        ``np.ndarray`` of dtype float32 with one row per video.

    Raises:
        ValueError: If ``mode`` is not ``'mean'`` or ``'statpool'``.
        RuntimeError: If a video yields no frames.
    """
    # Validate up front instead of after the first (expensive) video pass.
    if mode not in ('mean', 'statpool'):
        raise ValueError('Embedding aggregation mode not found.')

    resize_normalize = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # Feed inputs to wherever the model actually lives rather than assuming CUDA.
    try:
        model_device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        model_device = torch.device('cuda')

    embeddings = []

    # Pure inference: disabling autograd avoids building graphs for every frame.
    with torch.no_grad():
        for video_path in tqdm(video_paths, desc='Extracting video features...', total=len(video_paths)):
            vid, _ = sample_video_from_mp4(video_path)

            frame_embeddings = [
                model(resize_normalize(img).unsqueeze(0).to(model_device)).flatten().cpu()
                for img in vid
            ]

            if not frame_embeddings:
                raise RuntimeError(f'No frames could be sampled from {video_path!r}.')

            stacked = torch.stack(frame_embeddings)

            if mode == 'mean':
                pooled = stacked.mean(0)
            else:
                # NOTE: std over a single frame is NaN (unbiased estimator).
                pooled = torch.cat([stacked.mean(0), stacked.std(0)])

            embeddings.append(pooled.numpy())

    return np.array(embeddings, dtype=np.float32)

In [31]:
# Shortcut: use the activations right before the classifier instead of
# extracting features at every block.
# Whether per-block extraction works better is still an open question; this is
# the simpler option for now.

image_model = torchvision.models.densenet169(pretrained=True)

# Replace the classification head and the final batch norm with pass-throughs.
image_model.classifier = nn.Identity()
image_model.features.norm5 = nn.Identity()

# eval() and to() both return the module, so the cell still displays the model.
image_model.eval().to(device)

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [41]:
def _get_densenet169_blocks(idx):
    assert idx in [-1, -3, -5, -7]
    
    layers = list(image_model.features.children())[:idx]
    layers += [nn.AdaptiveAvgPool2d((1,1)), nn.Flatten(start_dim=1)]
    
    return nn.Sequential(*layers)

In [42]:
# One mean-pooled embedding matrix per DenseNet truncation point.
block_embeddings = list()

for idx in (-1, -3, -5, -7):
    model = _get_densenet169_blocks(idx)
    block_embeddings += [_extract_densenet_embeddings(model, video_paths)]

Extracting video features...:   0%|          | 0/1102 [00:00<?, ?it/s]

Extracting video features...:   0%|          | 0/1102 [00:00<?, ?it/s]

Extracting video features...:   0%|          | 0/1102 [00:00<?, ?it/s]

Extracting video features...:   0%|          | 0/1102 [00:00<?, ?it/s]

In [43]:
# Sanity check: one (num_videos, num_features) shape per truncation point.
for b in block_embeddings:
    print(np.shape(b))

(1102, 1664)
(1102, 1280)
(1102, 512)
(1102, 256)


In [45]:
# Persist all block embeddings joined along the feature axis as a single matrix.
np.save(
    'densenet169_concat_all_blocks',
    np.concatenate(block_embeddings, axis=-1),
)

In [44]:
# Also store each block's embeddings on its own, keyed by feature width.
for b in block_embeddings:
    np.save('densenet169_img_embeddings_{}'.format(b.shape[-1]), b)

In [22]:
# NOTE(review): this reads 'densenet169_img_embeddings.npy', whereas the save
# cell writes 'densenet169_img_embeddings_<dim>.npy' — confirm which file is intended.
all_embeddings = np.load('densenet169_img_embeddings.npy')

# Rows 0-999 become the training embeddings, the remaining rows the test
# embeddings. The original referenced the undefined name `img_embeddings`.
train_embeddings, test_embeddings = np.split(all_embeddings, [1000], axis=0)

print(all_embeddings.shape)

(1102, 1664)


In [6]:
# Subject folders available in the mini track.
sub_folders = sorted(os.listdir('../../Algonauts2021_devkit/participants_data_v2021/mini_track/'))

# ROI names are the extension-less file names found in the first subject's folder.
mini_track_ROIS = sorted(
    Path(file_name).stem
    for file_name in os.listdir('../../Algonauts2021_devkit/participants_data_v2021/mini_track/sub01/')
)

def _load_data(sub, ROI, fmri_dir='../../Algonauts2021_devkit/participants_data_v2021', batch_size=128, use_gpu=True):
    """Load the fMRI responses of one subject for one ROI.

    The ROI name selects the track sub-directory: ``"WB"`` reads from
    ``full_track``, any other ROI from ``mini_track``.

    Args:
        sub: Subject folder name, e.g. ``'sub04'``.
        ROI: ROI name, or ``"WB"`` for the full track.
        fmri_dir: Root directory containing the ``full_track`` and
            ``mini_track`` folders.
        batch_size: Unused; kept for interface compatibility.
        use_gpu: Unused; kept for interface compatibility.

    Returns:
        Whatever ``get_fmri`` returns as fMRI data for the subject/ROI. For the
        full track the second element returned by ``get_fmri`` (the voxel mask)
        is discarded.
    """
    track = "full_track" if ROI == "WB" else "mini_track"
    sub_fmri_dir = os.path.join(fmri_dir, track, sub)

    # The former `results_dir` local was never used and made this loader depend
    # on the global `image_model`; it has been removed.
    if track == "full_track":
        fmri_train_all, _ = get_fmri(sub_fmri_dir, ROI)
    else:
        fmri_train_all = get_fmri(sub_fmri_dir, ROI)

    return fmri_train_all

In [15]:
# Shell check that the EBA file for sub04 is present in the mini-track folder.
!ls ../../Algonauts2021_devkit/participants_data_v2021/mini_track/sub04/EBA.pkl

../../Algonauts2021_devkit/participants_data_v2021/mini_track/sub04/EBA.pkl


In [24]:
# Pick one subject and the first ROI of the sorted mini-track list, then load
# the corresponding fMRI responses.
sub, ROI = 'sub04', mini_track_ROIS[0]

train_fmri = _load_data(sub=sub, ROI=ROI)

{'train': array([[[-6.95004166e-01, -2.33046421e-01, -3.62619130e-02, ...,
         -1.46361185e+00, -2.65478392e+00,  6.40064911e-01],
        [-2.81774627e-01, -9.27002327e-01,  9.01490036e-02, ...,
         -1.89008356e-03,  4.03722716e-01,  6.12871734e-01],
        [ 2.58116807e-02, -7.02955570e-01,  7.76018936e-01, ...,
         -4.09921489e-01, -3.87421912e-01,  2.42603250e-01]],

       [[-1.31430813e+00, -6.14503634e-01, -7.77736832e-02, ...,
         -2.28237866e-01,  7.09095740e-02,  4.78120847e-01],
        [-2.75905393e-01, -5.44412503e-01,  3.63831838e-01, ...,
         -8.33588969e-02,  6.37616045e-02, -8.95454289e-01],
        [ 2.94103957e-01,  1.21511936e+00,  5.27384531e-02, ...,
          1.60667541e+00, -4.57609457e-02, -5.50046167e-02]],

       [[ 4.14587515e-01,  1.46559733e-01,  1.12509313e+00, ...,
         -6.04523029e-01, -8.07886635e-01, -5.66291811e-01],
        [-8.33496618e-03,  5.15193824e-01,  4.66992968e-01, ...,
          2.99543042e-01,  1.17609743e+

In [25]:
# Seed the split so the train/validation partition is reproducible regardless
# of how much global RNG state was consumed by earlier cells.
# NOTE: this cell rebinds `train_embeddings`, so it is not idempotent — re-run
# the embedding-loading cell before executing it a second time.
train_embeddings, val_embeddings, train_voxels, val_voxels = train_test_split(
    train_embeddings, train_fmri, random_state=_RANDOM_STATE)

In [7]:
class MultiVoxelAutoEncoder(pl.LightningModule):
    """Fully connected autoencoder over flat feature vectors.

    The encoder maps ``in_features -> 1024 -> 512 -> out_features`` and the
    decoder mirrors it back to ``in_features``. Training minimizes the MSE
    between the input and its reconstruction; the second element of each
    batch is not used by the loss.

    Args:
        in_features: Size of the input (and reconstructed) vectors.
        out_features: Size of the bottleneck code returned by ``forward``.
    """

    def __init__(self, in_features=1668, out_features=368):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, out_features)
        )

        self.decoder = nn.Sequential(
            nn.Linear(out_features, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, in_features)
        )

    def forward(self, x):
        """Return the bottleneck code of ``x``; the decoder is only used in training."""
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one batch.

        Args:
            batch: Pair ``(x, y)``; only ``x`` enters the loss.
            batch_idx: Index of the batch (unused).

        Returns:
            Scalar MSE between the flattened input and its reconstruction.
        """
        x, _ = batch
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        x_hat = self.decoder(self.encoder(x))
        # `F` is never imported in this file; go through the already-imported
        # `nn` namespace instead of relying on a star import to provide it.
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [8]:
# Instantiate with the default sizes on the GPU and print a layer-by-layer summary.
model = MultiVoxelAutoEncoder().cuda()

summary(model, input_size=(1, 1668))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 1, 1024]       1,709,056
              ReLU-2              [-1, 1, 1024]               0
            Linear-3               [-1, 1, 512]         524,800
              ReLU-4               [-1, 1, 512]               0
            Linear-5               [-1, 1, 368]         188,784
Total params: 2,422,640
Trainable params: 2,422,640
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.03
Params size (MB): 9.24
Estimated Total Size (MB): 9.27
----------------------------------------------------------------


In [38]:
class VoxelDataset(Dataset):
    """Index-aligned dataset of feature vectors and optional voxel targets.

    Args:
        features: Indexable, sized collection of input features.
        voxel_maps: Optional indexable collection of targets aligned with
            ``features``. When omitted, items consist of the feature only.
        transform: Optional callable applied to each feature on access.
    """

    def __init__(self, features, voxel_maps=None, transform=None):
        self.features = features
        self.voxel_maps = voxel_maps
        self.transform = transform

    def __len__(self):
        """Return the number of feature entries."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(feature, voxel_map)``, or just ``feature`` without targets."""
        feat = self.features[idx]

        if self.transform is not None:
            feat = self.transform(feat)

        # Compare against None explicitly: the truth value of a multi-element
        # NumPy array is ambiguous and raises ValueError.
        if self.voxel_maps is None:
            # Previously this path hit an unbound `voxel_map` at the return.
            return feat

        return feat, self.voxel_maps[idx]

In [39]:
# Wrap the split arrays as datasets.
train_dataset = VoxelDataset(features=train_embeddings, voxel_maps=train_voxels)
val_dataset = VoxelDataset(features=val_embeddings, voxel_maps=val_voxels)

# DataLoader defaults spelled out: one sample per batch, no shuffling.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [40]:
# Size the input layer from the data: the loaded embeddings are 1664-dim, which
# does not match the class default of in_features=1668.
mvae = MultiVoxelAutoEncoder(in_features=train_embeddings.shape[-1])

In [43]:
# Fit the autoencoder on the training loader only; `val_loader` is not passed.
# NOTE(review): `gpus=2` hard-codes two GPUs — confirm this matches the machine
# and the installed pytorch_lightning version's Trainer arguments.
trainer = pl.Trainer(gpus=2)
trainer.fit(mvae, train_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


KeyboardInterrupt: 