In [None]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import h5py
from tqdm import tqdm

import webdataset as wds
import gc

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms

from accelerate import Accelerator, DeepSpeedPlugin

# Let CUDA matmuls use TensorFloat-32: noticeably faster than full float32 on
# GPUs that support it, at the cost of a shorter mantissa.
torch.backends.cuda.matmul.allow_tf32 = True

# Project-local helper functions.
import utils

# Total batch size summed over every process; the per-device batch size is derived from it later.
global_batch_size = 64

In [None]:
### Multi-GPU config ###
# Rank of this process on its node. Distributed launchers export LOCAL_RANK; fall back to
# RANK (the global rank, identical on a single node) and finally to 0 for a plain run.
local_rank = int(os.getenv('LOCAL_RANK', os.getenv('RANK', 0)))
print("LOCAL RANK ", local_rank)  

# Visible GPU count; a CPU-only machine counts as one device so the per-device
# batch-size division later in the notebook never divides by zero.
num_devices = max(torch.cuda.device_count(), 1)

# split_batches=False is Accelerator's default, so it is not passed explicitly; this keeps the
# call valid regardless of whether the installed accelerate still accepts that argument directly
# (newer releases configure it through DataLoaderConfiguration).
accelerator = Accelerator()

### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###

# if num_devices <= 1 and utils.is_interactive():
#     # can emulate a distributed environment for deepspeed to work in jupyter notebook
#     os.environ["MASTER_ADDR"] = "localhost"
#     os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
#     os.environ["RANK"] = "0"
#     os.environ["LOCAL_RANK"] = "0"
#     os.environ["WORLD_SIZE"] = "1"
#     os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
#     # environment values are strings; convert back so batch-size arithmetic keeps working
#     global_batch_size = int(os.environ["GLOBAL_BATCH_SIZE"])

# # alter the deepspeed config according to your global and local batch size
# if local_rank == 0:
#     with open('deepspeed_config_stage2.json', 'r') as file:
#         config = json.load(file)
#     config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
#     config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
#     with open('deepspeed_config_stage2.json', 'w') as file:
#         json.dump(config, file)
# else:
#     # give some time for the local_rank=0 gpu to prep new deepspeed config file
#     time.sleep(10)
# deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
# accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)

In [None]:
# Report where this process runs and how the distributed job is laid out.
print("PID of this process =", os.getpid())

device = accelerator.device
print("device:", device)

num_workers = num_devices
print(accelerator.state)

world_size = accelerator.state.num_processes
# Any distributed_type other than 'NO' means we are running under a distributed backend.
distributed = accelerator.state.distributed_type != 'NO'
print("distributed =", distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)

# From here on, print() only emits output on the main process.
print = accelerator.print

# Config

In [None]:
# Per-process batch size. int() tolerates global_batch_size arriving as a string
# (e.g. read back from os.environ); floor division keeps the result integral.
batch_size = int(global_batch_size) // num_devices
epochs = 800
save_ckpt_freq = 50
model_name = "pretrain_videomae_base_patch16_224"
decoder_depth = 4
mask_type = "tube"  # 'tube' or 'random' (see the masking-generator choice in the dataloader cell)
mask_ratio = .75  # fraction of tokens/patches to mask
input_size = [65,78,65]  # per-axis volume size, divided by patch_size to get the token grid
drop_path = 0.0
normlize_target = True  # (sic) name kept as-is in case later cells reference it

lr = 1.5e-4
warmup_lr = 1e-6
min_lr = 1e-5
warmup_epochs = 40
warmup_steps = -1  # presumably "-1 = derive from warmup_epochs" — TODO confirm where it is consumed

use_checkpoint = False
resume = ""
auto_resume = False
start_epoch = 0

color_jitter = 0.0
train_interpolation = "bicubic"
data_urls = "/scratch/openneuro-0-100-ps13-f8-r1-bspline-shuffled/func-{000000..000577}.tar"
data_location = "aws"
data_resample = True
data_seed = 42
data_buffer_size = 100  # buffer size for shuffling the datasets
data_batch_per_epoch = 1000
max_num_patches = 196*4  # token cap handed to Patchify and the masking generators
num_frames = 8
tubelet_size = 2
sampling_rate = 1
output_dir = ""
log_dir = None
seed = 0

# device, world_size and local_rank are deliberately NOT reassigned here: the multi-GPU cells
# above derive them from the environment/Accelerator, and hard-coding them ("cuda", 1, 0)
# made every process believe it was rank 0 of a single-process job.
num_workers = 1  # overrides the earlier `num_workers = num_devices`
pin_mem = True

In [None]:
import timm

# Build the masked-autoencoder pretraining model from the timm registry.
# NOTE(review): this model name is presumably registered by a project module that is not
# imported in this cell — TODO confirm which import provides it.
model = timm.create_model(
    model_name,
    pretrained=False,
    drop_path_rate=drop_path,
    drop_block_rate=None,
    decoder_depth=decoder_depth,
    use_checkpoint=use_checkpoint,
)

patch_size = model.encoder.patch_embed.patch_size
print("Patch size", patch_size)

# Token grid: number of temporal tubelets, then input_size // patch_size for each of the three axes.
window_size = (num_frames // tubelet_size,) + tuple(
    input_size[axis] // patch_size[axis] for axis in range(3)
)
print("Window size", window_size)

# Dataloader

In [None]:
from masking_generator import TubeMaskingGenerator, AgnosticMaskingGenerator
class DataAugmentationForVideoMAE():
    """Per-sample preprocessing for masked-autoencoder pretraining on fMRI volumes.

    Applies ``self.transform`` to a volume, patchifies the result, and draws a
    masking pattern for the resulting token grid.

    All configuration is read from notebook globals (``patch_size``, ``tubelet_size``,
    ``max_num_patches``, ``mask_type``, ``window_size``, ``mask_ratio``).

    Args:
        args: Not used by the body. Defaults to ``None`` so the class can be
            built without arguments; passing a value is still accepted.

    Raises:
        ValueError: if ``mask_type`` is neither ``'tube'`` nor ``'random'``.
    """

    def __init__(self, args=None):
        # TODO: add augmentation for fMRI (torchio candidates below; `tio` is not imported)
        self.transform = transforms.Compose([
            # tio.CropOrPad((65, 78, 65)),
            # tio.RandomFlip(axes=('LR',)),
            # tio.RandomAffine(scales=(0.9, 1.1), degrees=10, isotropic=False, default_pad_value='otsu'),
            # tio.RandomAffine(scales=(0.9, 1.1)),
            # tio.RandomNoise(std=(0, 0.1)),
            # tio.RandomBlur(std=(0, 0.1)),
            # tio.RandomBiasField(coefficients=(0, 0.1)),
            ])
        # NOTE(review): `Patchify` is not defined or imported in the visible notebook — TODO confirm its source.
        self.patchify = Patchify(patch_size, tubelet_size, max_num_patches)
        if mask_type == 'tube':
            self.masked_position_generator = TubeMaskingGenerator(
                window_size, mask_ratio, max_num_patches
            )
        elif mask_type == 'random':
            self.masked_position_generator = AgnosticMaskingGenerator(
                window_size, mask_ratio, max_num_patches
            )
        else:
            # Fail at construction instead of with an AttributeError on the first sample.
            raise ValueError(
                f"Unknown mask_type {mask_type!r}; expected 'tube' or 'random'"
            )

    def __call__(self, sample):
        """Process one ``(key, volume)`` pair.

        Returns:
            ``(padded, token_shape, mask, masked_positions)``: the first three come from
            ``self.patchify`` (``mask`` presumably marks padding — confirm), the last
            from the masking generator applied to ``token_shape``.
        """
        _key, func = sample  # the key is not needed for preprocessing
        process_data = self.transform(func)
        paded, mask, token_shape = self.patchify(process_data)
        return paded, token_shape, mask, self.masked_position_generator(token_shape)

    def __repr__(self):
        text = "(DataAugmentationForVideoMAE,\n"
        text += "  transform = %s,\n" % str(self.transform)
        text += "  Masked position generator = %s,\n" % str(self.masked_position_generator)
        text += ")"
        return text
# The constructor's `args` parameter is unused; pass None explicitly so this call works whether or not it has a default.
transform = DataAugmentationForVideoMAE(None)

In [None]:
# Shard source: with data_location == "aws", each shard is streamed through the AWS CLI
# ("pipe:" makes webdataset read the command's stdout); otherwise data_urls is read directly.
# NOTE(review): data_urls is currently a local /scratch path, which `aws s3 cp` does not accept
# as a source — confirm whether data_urls or data_location needs updating.
if data_location == "aws":
    urls = f"pipe:aws s3 cp {data_urls} -"
else:
    urls = data_urls
print(urls)

# Shuffle with a seeded RNG, preprocess each sample, emit only full batches, and cap each
# epoch at data_batch_per_epoch (needed because resampled shards make the stream endless).
dataset_train = (
    wds.WebDataset(urls, resampled=data_resample)
    .decode("torch")
    .to_tuple("__key__", "func.npy")
    .shuffle(data_buffer_size, initial=data_buffer_size, rng=random.Random(data_seed))
    .map(transform)
    .batched(batch_size, partial=False)
    .with_epoch(data_batch_per_epoch)
)
print("Data Aug = %s" % str(transform))