In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch

from accelerate import notebook_launcher

from utils.data_utils import BrainDataset, process_file_last
from utils.train_utils import TrainConfig, run_train_model, count_parameters, load_model_weights

from dataclasses import dataclass
from simple_parsing.helpers import Serializable

import albumentations as A

import matplotlib.pyplot as plt
import gc

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.8 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations


In [2]:
from models.brain_mae import EncoderConfig, Encoder, MAE, MAEConfig

### Init models

In [3]:
# Build the model from the default EncoderConfig and MAEConfig; the MAE constructor
# prints both configs and the parameter counts (see output below).
encoder_config, mae_config = EncoderConfig(), MAEConfig()
model = MAE(mae_config=mae_config, encoder_config=encoder_config)


EncoderConfig(window_size=16, n_electrodes=256, time_patch_size=4, n_features=4, n_layers=12, dim=512, hidden_dim=2048, head_dim=32, n_heads=16, n_kv_heads=16)
Simple Encoder: number of parameters: 50.50M
mae_config MAEConfig(masking_ratio=0.5, n_layers=6, dim=512, hidden_dim=2048, head_dim=32, n_heads=16, n_kv_heads=16)
Full MAE: number of parameters: 75.94M


In [4]:
data_path = Path("/drive/data/competitionData")

tokenize_function = None
process_file_function = process_file_last

window_size = encoder_config.window_size
# Width of each input frame. Derived from the encoder config (instead of a
# hard-coded 256 * 4) so the crops stay in sync with the model; with the default
# EncoderConfig printed above (n_electrodes=256, n_features=4) this is 1024.
n_electrodes = encoder_config.n_electrodes * encoder_config.n_features
max_tokens = 25

# Training: random `window_size` x `n_electrodes` crop.
train_transform = A.Compose([
    A.RandomCrop(height=window_size, width=n_electrodes, always_apply=True),
])

# Evaluation: deterministic. Zero-pad (signal anchored top-left) up to at least
# `window_size` x `n_electrodes`, then take exactly that top-left region.
test_transform = A.Compose([
    A.PadIfNeeded(min_height=window_size, min_width=n_electrodes, position='top_left',
                  border_mode=0, value=0, always_apply=True),
    A.Crop(x_min=0, x_max=n_electrodes, y_min=0, y_max=window_size, always_apply=True)
])


def _make_brain_dataset(split, transform):
    """Build a BrainDataset for `data_path / split` with the shared settings above."""
    dataset = BrainDataset(data_path / split,
                           process_file_function=process_file_function,
                           tokenize_function=tokenize_function,
                           transform=transform,
                           max_tokens=max_tokens)
    gc.collect()  # reclaim memory between the per-split loads
    return dataset


train_split = _make_brain_dataset('train', train_transform)
test_split = _make_brain_dataset('test', test_transform)
submit_dataset = _make_brain_dataset('competitionHoldOut', test_transform)

# All three splits are concatenated for training. The hold-out split is ALSO used
# as the evaluation set, so it is not unseen data: eval loss measures fit, not
# generalization.
train_dataset = torch.utils.data.ConcatDataset([train_split, test_split, submit_dataset])
test_dataset = submit_dataset
print(len(train_dataset))

Runed processing of the  /drive/data/competitionData/train


Processing files: 100%|██████████| 24/24 [00:33<00:00,  1.39s/file]


len of the dataset: 8800
max signal size: 906 | max tokens size: 87
median signal size: 297.0 | median tokens size: 31.0
Runed processing of the  /drive/data/competitionData/test


Processing files: 100%|██████████| 24/24 [00:03<00:00,  7.73file/s]


len of the dataset: 880
max signal size: 919 | max tokens size: 86
median signal size: 283.5 | median tokens size: 30.0
Runed processing of the  /drive/data/competitionData/competitionHoldOut


Processing files: 100%|██████████| 15/15 [00:04<00:00,  3.59file/s]


len of the dataset: 1200
max signal size: 594 | max tokens size: 8
median signal size: 290.5 | median tokens size: 8.0
10880


In [5]:
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Smoke test: a single forward pass on one hold-out sample to check that the
# dataset output and the model interface line up.
x, y, date = next(iter(test_dataloader))

# no_grad: without it the global `loss` keeps the whole autograd graph (and its
# activations) alive for the rest of the session, including during training.
with torch.no_grad():
    loss, logits = model(x, y, date)


### Train the MAE on all data splits

In [None]:
project_name = 'simple_mae'
save_folder = Path("/drive/logs/kovalev")

# NOTE(review): presumably the effective batch per optimizer step is
# batch_size * grad_accum = 128 * 8 = 1024 samples — confirm in TrainConfig/run_train_model.
train_config = TrainConfig(exp_name='brain-mae-full-datasets-train',
                           mixed_precision=False,
                           batch_size=128,
                           grad_accum=8,
                           num_workers=4,
                           pin_memory=True,
                           eval_interval=1000,
                           learning_rate=1e-4,
                           weight_decay=0.001,
                           grad_clip=10,
                           lr_decay_iters=40_000,
                           warmup_iters=500,
                           project_name=project_name,
                           save_folder=save_folder
                          )

# Debug toggle: uncomment to train on sample 0 repeated 128 times (one batch),
# presumably to check that the training loop can overfit.
# train_dataset = torch.utils.data.Subset(train_dataset, torch.arange(1).repeat(128))

# torch.compile returns a wrapper that keeps the original module as `_orig_mod`;
# skip re-wrapping so re-running this cell does not compile an already compiled model.
if not hasattr(model, '_orig_mod'):
    model = torch.compile(model)

args = (model, (train_dataset, test_dataset), train_config, encoder_config)
notebook_launcher(run_train_model, args, num_processes=1)

Launching training on one GPU.


dataloader_config = DataLoaderConfiguration(split_batches=True)
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device for training:  cuda
Num devices:  1
Completed initialization of scheduler




***************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************


In [None]:
def visualize_data(x, recon, binary_mask, t_range=(0, 32), c_range=(256, 280)):
    """Plot an input window next to its reconstruction, split by `binary_mask`.

    Layout (3 rows x 2 columns): the left column is derived from `x`, the right
    column from `recon`.
    - Row 0: the full signal.
    - Row 1: entries where the mask is False ("unmasked").
    - Row 2: entries where the mask is True ("masked").
    Each panel shows ``arr[t_range[0]:t_range[1], c_range[0]:c_range[1]].T``, so
    the first axis of the inputs runs horizontally.

    Args:
        x: input signal, numpy array or tensor, indexed as [time, channel].
        recon: reconstruction with the same layout as `x`.
        binary_mask: mask with the same layout as `x`; converted to bool.
        t_range: half-open (start, end) range of the first axis to display.
        c_range: half-open (start, end) range of the second axis to display.
    """
    # Normalise everything to CPU tensors so the arithmetic below never mixes
    # numpy and torch, and so `~` is a logical (not bitwise) negation.
    x = torch.as_tensor(x).detach().cpu()
    recon = torch.as_tensor(recon).detach().cpu()
    mask = torch.as_tensor(binary_mask).detach().cpu().to(torch.bool)

    (t_st, t_end), (c_st, c_end) = t_range, c_range

    def _window(arr):
        # Crop to the requested window; transpose so the first axis is horizontal.
        return arr[t_st:t_end, c_st:c_end].T

    panels = [
        [(x, 'input'), (recon, 'reconstruction')],
        [(~mask * x, 'input (unmasked)'), (~mask * recon, 'reconstruction (unmasked)')],
        [(mask * x, 'input (masked)'), (mask * recon, 'reconstruction (masked)')],
    ]

    fig, ax = plt.subplots(3, 2, figsize=(16, 20))
    for row, row_panels in enumerate(panels):
        for col, (data, title) in enumerate(row_panels):
            ax[row, col].imshow(_window(data), vmax=1, aspect='auto')
            ax[row, col].set(title=title, xlabel='time index', ylabel='channel index')


In [None]:
# nn.Module.float() casts floating-point parameters/buffers in place and returns
# the same module, so rebinding is unnecessary; ';' stops Jupyter echoing the repr.
model.float();

In [None]:
# nn.Module has no built-in `dtype` attribute, so inspect the parameters directly.
# A single-element set confirms the cast above reached every parameter.
{p.dtype for p in model.parameters()}

In [None]:
x, txt, date_info = train_dataset[0]

# Number of consecutive `window_size`-row windows of `x` to reconstruct.
# x must have at least n_samples * window_size rows.
n_samples = 1
device = 'cuda'

# Convert the date info once, outside the loop. Re-wrapping it every iteration
# would call torch.from_numpy on a tensor and fail for n_samples > 1.
date_batch = torch.from_numpy(date_info).to(device)[None, ...]

recons = []
binary_mask = []
with torch.no_grad():
    for i in range(n_samples):
        # (i + 1) * window_size: the original `i+1*window_size` parsed as i + window_size.
        x_slice = x[i * window_size:(i + 1) * window_size]
        x_slice = torch.from_numpy(x_slice).to(device)[None, ...].float()

        loss, recon, binary = model(x_slice, date_info=date_batch, return_preds=True)

        # Drop the batch dimension and move results to CPU for plotting.
        recon = recon[0].detach().cpu()
        binary = binary[0].detach().cpu()
        recons.append(recon)
        binary_mask.append(binary)

# Stitch the windows back together along axis 0, the axis that was sliced above.
recons = torch.cat(recons, dim=0)
binary_mask = torch.cat(binary_mask, dim=0).to(torch.bool)

In [None]:
# Only a cell's last expression is displayed, so print the shapes explicitly.
print(binary_mask.shape, recons.shape, x.shape)

visualize_data(x, recons, binary_mask)

In [None]:
# x_gt = x_gt[None, ...]