## Environment setup

In [None]:
# Show which GPU (if any) is attached to this runtime before doing any work.
!nvidia-smi

In [None]:
# Colab only: mount Google Drive at /gdrive and make it the working directory.
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive

In [None]:
!pip install decord
!pip install albumentations -U
!pip install git+https://github.com/rwightman/pytorch-image-models

In [None]:
import sys
# Put the project's source folder on the import path; the project-local imports
# in the following cells (settings, model, runner, data) presumably resolve from here.
sys.path.insert(1, "/gdrive/My Drive/Projects/algonauts/src/mini-track/src_new/")

In [None]:
import os
import numpy as np

import torch
import timm
from tqdm.notebook import tqdm

In [None]:
import settings

In [None]:
!mkdir {settings.DATA_FOLDER}
!unzip -qq {settings.PROJECT_FOLDER}data/participants_data.zip -d {settings.DATA_FOLDER}

## Training helper functions

In [None]:
import model.utils as model_utils
import runner.utils as runner_utils
import data.utils as data_utils

import data.handler as data_handler
from data.dataset import VidDataset, VidFMRIDataset

from runner import runner_gpu as runner
from runner.loss import WeightedMSELoss
from runner.metric import vectorized_correlation

from torch.utils.data import DataLoader

### Train

In [None]:
def append_to_history(history_dict, target, output, loss, score):
    """Append one epoch's results to a history dict, in place.

    Args:
        history_dict: dict holding the lists "loss", "score", "outputs" and
            "targets"; each list receives exactly one new entry.
        target: value appended to history_dict["targets"].
        output: value appended to history_dict["outputs"].
        loss: value appended to history_dict["loss"].
        score: value appended to history_dict["score"].
    """
    epoch_record = (
        ("loss", loss),
        ("score", score),
        ("outputs", output),
        ("targets", target),
    )
    for key, value in epoch_record:
        history_dict[key].append(value)


def train_and_validate(args):
    """Train a model on clips 0-899 and validate it on clips 900-999 after every epoch.

    After each epoch the training/validation history and the model checkpoints
    ("best" and "last") are written to disk, so an interrupted run still leaves
    usable files behind.

    Args:
        args: run configuration object. Reads ``vid_files``, ``fmri_data``,
            ``fmri_voxel_total``, ``model_params``, ``device``, ``num_workers``,
            ``batch_size``, ``optimizer``, ``optimizer_params``, ``lr``,
            ``epochs``, ``fmri_mapping``, ``output_valid_fn`` and
            ``output_model_fn``. ``fmri_data`` is sliced along its first of
            three axes together with ``vid_files``.

    Side effects:
        Sets ``args.model_history`` and ``args.output_history``; writes
        ``args.output_valid_fn`` and ``args.output_model_fn`` once per epoch.
    """
    import copy  # local import keeps this cell independent of the imports cell

    # SEED and split
    runner_utils.seed_everything(seed=settings.SEED)

    # Fixed split: first 900 clips for training, the next 100 for validation
    # (evaluate() uses the same validation range).
    train_vid_files = args.vid_files[0:900]
    valid_vid_files = args.vid_files[900:1000]
    train_fmri_data = args.fmri_data[0:900, :, :]
    valid_fmri_data = args.fmri_data[900:1000, :, :]

    # create model
    model = model_utils.get_model(output_size=args.fmri_voxel_total, **args.model_params)
    model = model.to(args.device)
    # train data loader
    train_dataset = VidFMRIDataset(train_vid_files, train_fmri_data, transform=data_utils.get_train_transform(), fmri_transform="combination_augment")
    train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
    # valid loader
    valid_dataset = VidFMRIDataset(valid_vid_files, valid_fmri_data, transform=data_utils.get_test_transform(), fmri_transform="mean_over_rep")
    valid_loader = DataLoader(valid_dataset, num_workers=args.num_workers, batch_size=1, shuffle=False)
    # loss, optimizer and scheduler
    train_criterion = WeightedMSELoss(reduction="mean")
    valid_criterion = torch.nn.MSELoss()
    optimizer = args.optimizer(model.parameters(), lr=args.lr, **args.optimizer_params)
    # NOTE(review): steps_per_epoch=len(train_loader) assumes runner.train_epoch
    # steps the scheduler once per batch — confirm against the runner.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
                    optimizer,
                    max_lr=args.lr,
                    epochs=args.epochs,
                    steps_per_epoch=len(train_loader),
                    div_factor=10,
                    final_div_factor=1,
                    pct_start=0.1,
                    anneal_strategy="cos",
                )
    # history dicts
    train_history = {"loss": [], "score": [], "outputs": [], "targets": []}
    valid_history = {"loss": [], "score": [], "outputs": [], "targets": []}
    model_history = {"best": {"state" : None, "score": 0, "epoch": -1},
                     "last": {"state" : None, "score": 0, "epoch": -1}}
    # Built once up front: the history dicts are extended in place, and this
    # keeps output_history defined even when args.epochs < 1.
    output_history = {"train": train_history, "valid": valid_history, "mapping": args.fmri_mapping}

    # train - validate - save
    for epoch in range(1, args.epochs +1):
        targets_all, outputs_all, loss, score = runner.train_epoch(args, model, train_loader, train_criterion, optimizer, scheduler, epoch)
        # Training outputs/targets are deliberately not stored, only loss and score.
        append_to_history(train_history, None, None, loss, score)
        runner_utils.print_score(outputs_all, targets_all, args.fmri_mapping, "\n\t")

        targets_all, outputs_all, loss, score = runner.validate(args, model, valid_loader, valid_criterion)
        append_to_history(valid_history, targets_all, outputs_all, loss, score)
        runner_utils.print_score(outputs_all, targets_all, args.fmri_mapping, "\t")

        # state_dict() hands out references to the live parameters; that is
        # fine for "last", which is meant to track the current weights.
        model_history["last"]["state"] = model.state_dict()
        model_history["last"]["score"] = score
        model_history["last"]["epoch"] = epoch

        best = model_history["best"]
        # Also accept the first epoch unconditionally so that "best" is filled
        # even when every validation score is <= 0.
        if best["state"] is None or score > best["score"]:
            # Deep copy: without it later epochs would keep mutating the
            # tensors stored here and "best" would always equal "last".
            best["state"] = copy.deepcopy(model.state_dict())
            best["score"] = score
            best["epoch"] = epoch

        data_handler.save_dict(output_history, args.output_valid_fn)
        torch.save(model_history, args.output_model_fn)

    args.model_history = model_history
    args.output_history = output_history

### Evaluate

In [None]:
def evaluate(args):
    """Score the saved "last" checkpoint on the validation clips (900-999) and print the result.

    Args:
        args: run configuration object. Reads ``vid_files``, ``fmri_data``,
            ``num_workers``, ``output_model_fn``, ``fmri_voxel_total``,
            ``model_params``, ``device`` and ``fmri_mapping``.
    """
    # Same validation range as in train_and_validate().
    valid_vid_files = args.vid_files[900:1000]
    valid_fmri_data = args.fmri_data[900:1000, :, :]

    valid_dataset = VidFMRIDataset(valid_vid_files, valid_fmri_data, transform=data_utils.get_test_transform(), fmri_transform="mean_over_rep")
    valid_loader = DataLoader(valid_dataset, num_workers=args.num_workers, batch_size=1, shuffle=False)

    # map_location lets a checkpoint saved from GPU tensors be restored on a
    # CPU-only runtime (args.device falls back to 'cpu' in that case).
    checkpoint = torch.load(args.output_model_fn, map_location=args.device)
    model_state = checkpoint["last"]["state"]

    model = model_utils.get_model(output_size=args.fmri_voxel_total, **args.model_params)
    model = model.to(args.device)
    model.load_state_dict(model_state)

    valid_criterion = torch.nn.MSELoss()
    targets_all, outputs_all, loss, score = runner.validate(args, model, valid_loader, valid_criterion)
    runner_utils.print_score(outputs_all, targets_all, args.fmri_mapping, "\t")

### Predict

In [None]:
def predict(args):
    """Run the saved "last" checkpoint over the test videos and save the predictions.

    The result is stored through ``data_handler.save_dict`` at ``args.output_fn``
    as ``{"fmri_data": <list with one prediction row per test sample>,
    "mapping": args.fmri_mapping}``.

    Args:
        args: run configuration object. Reads ``test_vid_files``,
            ``num_workers``, ``output_model_fn``, ``fmri_voxel_total``,
            ``model_params``, ``device``, ``fmri_mapping`` and ``output_fn``.
    """
    # data and dataset
    test_vid_files = args.test_vid_files
    test_dataset = VidDataset(test_vid_files, transform=data_utils.get_test_transform())
    test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=1, shuffle=False)

    # model
    # map_location lets a checkpoint saved from GPU tensors be restored on a
    # CPU-only runtime (args.device falls back to 'cpu' in that case).
    checkpoint = torch.load(args.output_model_fn, map_location=args.device)
    model_state = checkpoint["last"]["state"]

    model = model_utils.get_model(output_size=args.fmri_voxel_total, **args.model_params)
    model = model.to(args.device)
    model.load_state_dict(model_state)

    # predict
    model.eval()
    outputs_all = []
    with torch.no_grad():
        for sample in tqdm(test_loader):
            vid_data = sample['vid_data'].to(args.device)
            fps = sample['fps'].to(args.device)
            batch_output = model(vid_data, fps)
            # extend() adds one entry per sample of the batch
            outputs_all.extend(batch_output.detach().cpu().numpy())

    output = {"fmri_data": outputs_all, "mapping": args.fmri_mapping}
    data_handler.save_dict(output, args.output_fn)

### Setup and run

In [None]:
def generate_model_name(model_params):
    """Build the model identifier "<model_type>-<backbone_name>-<embed_size>".

    Args:
        model_params: mapping with the keys "model_type", "backbone_name" and
            "embed_size".

    Returns:
        The three values joined by dashes, as a string.
    """
    model_type = model_params['model_type']
    backbone = model_params['backbone_name']
    embed_size = model_params['embed_size']
    return "{}-{}-{}".format(model_type, backbone, embed_size)

def generate_model_run_name(model_name, sub, rois):
    """Build the per-run identifier "<model_name>_<sub>_<roi1>-<roi2>-...".

    Args:
        model_name: model identifier placed first.
        sub: subject identifier placed second.
        rois: iterable of ROI name strings, joined with dashes.

    Returns:
        The run name as a string.
    """
    roi_tag = '-'.join(rois)
    return "{}_{}_{}".format(model_name, sub, roi_tag)

def run_sub_roi(args, sub, train_flag=False, evaluate_flag=False, predict_flag=False):
    """Prepare `args` for one subject and run the requested stages.

    `args.fmri_data` is indexed by `sub` on entry and is then replaced on
    `args` by that subject's "data" entry, so pass a fresh config object for
    every subject.

    Args:
        args: run configuration object; updated in place with the subject's
            data, ROI mapping, model name and output file names.
        sub: key of the subject inside `args.fmri_data`.
        train_flag: call train_and_validate(args) when true.
        evaluate_flag: call evaluate(args) when true.
        predict_flag: call predict(args) when true.

    Returns:
        The same `args` object that was passed in.
    """
    print(f"Running {sub}: train = {train_flag}, evaluate = {evaluate_flag}, predict = {predict_flag}")

    subject_entry = args.fmri_data[sub]

    # subject specific setup
    args.sub = sub
    roi_mapping = subject_entry["mapping"]
    args.rois = list(roi_mapping.keys())
    args.fmri_mapping = roi_mapping
    args.fmri_data = subject_entry["data"]
    # the size of the third axis is what the model is asked to output
    args.fmri_voxel_total = np.shape(args.fmri_data)[2]

    # output file names all share the same run name
    args.model_name = generate_model_name(args.model_params)
    model_run_name = generate_model_run_name(args.model_name, args.sub, args.rois)
    args.output_fn = f"{settings.OUTPUT_FOLDER}output_{model_run_name}.pkl"
    args.output_valid_fn = f"{settings.OUTPUT_FOLDER}output_valid_{model_run_name}.pkl"
    args.output_model_fn = f"{settings.OUTPUT_FOLDER}model_{model_run_name}.pt"

    # stages run in a fixed order: train, evaluate, predict
    if train_flag:
        print(f"Training: {model_run_name}")
        print(args.fmri_mapping)
        train_and_validate(args)
    if evaluate_flag:
        print(f"Evaluating: {model_run_name}")
        evaluate(args)
    if predict_flag:
        print(f"Predicting: {model_run_name}")
        predict(args)

    return args

## Train models and predict

In [None]:
# Load the fMRI data and the train/test video file lists through the project's data utilities.
# fmri_data is indexed by subject; each entry provides "mapping" and "data" (see run_sub_roi).
fmri_data, train_vid_files, test_vid_files = data_utils.get_data()

In [None]:
class base_args:
    """Settings shared by every model configuration in this notebook."""

    # data loaded in the previous cell
    fmri_data = fmri_data
    vid_files = train_vid_files
    test_vid_files = test_vid_files

    # training hyper-parameters
    epochs = 4
    lr = 1e-3
    batch_size = 4
    num_workers = 2

    # use the GPU when torch can see one, otherwise fall back to the CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class args(base_args):
    """Configuration for the "cnn-stats-adapt-1" model on the eca_nfnet_l0 backbone, optimised with Adam."""

    optimizer = torch.optim.Adam
    optimizer_params = dict()

    # forwarded as keyword arguments to model_utils.get_model()
    model_params = {
        "embed_size": 512,
        "model_type": "cnn-stats-adapt-1",
        "backbone_name": "eca_nfnet_l0",
        "linear_pool": None,
        "adaptive_pool": 1,
        "rnn_features": False,
    }

# Train and predict for every subject; run_sub_roi rewrites attributes on the
# config it receives, so each subject gets its own fresh instance.
for sub in settings.subs:
    _ = run_sub_roi(args(), sub, train_flag=True, predict_flag=True)

In [None]:
class args(base_args):
    """Configuration for the "cnn-rnn-adapt-1-lin-6" model on the resnet50 backbone, optimised with AdamW."""

    optimizer = torch.optim.AdamW
    optimizer_params = {"weight_decay": 0.02}

    # forwarded as keyword arguments to model_utils.get_model()
    model_params = {
        "embed_size": 1024,
        "model_type": "cnn-rnn-adapt-1-lin-6",
        "backbone_name": "resnet50",
        "linear_pool": 6,
        "adaptive_pool": 1,
        "rnn_features": True,
    }

# Train and predict for every subject; run_sub_roi rewrites attributes on the
# config it receives, so each subject gets its own fresh instance.
for sub in settings.subs:
    _ = run_sub_roi(args(), sub, train_flag=True, predict_flag=True)

## Ensemble predictions and create submission

In [None]:
import zipfile

# Names of the models to ensemble. They must be the full names produced by
# generate_model_name() ("<model_type>-<backbone_name>-<embed_size>"), because
# run_sub_roi() writes the prediction files under that full name.
model_names = ["cnn-stats-adapt-1-eca_nfnet_l0-512", "cnn-rnn-adapt-1-lin-6-resnet50-1024"]
predictions = {}

# predictions[model_name][sub] -> dict saved by predict() ("fmri_data", "mapping")
for model_name in model_names:
    predictions[model_name] = {}

    for sub in settings.subs:
        # NOTE(review): run_sub_roi builds the file name from the keys of each
        # subject's ROI mapping; this assumes settings.ROIs lists the same ROIs
        # in the same order — confirm.
        model_full_name = generate_model_run_name(model_name, sub, settings.ROIs)
        # local variable instead of writing scratch state onto the `args` class
        output_fn = f"{settings.OUTPUT_FOLDER}output_{model_full_name}.pkl"
        predictions[model_name][sub] = data_handler.load_dict(output_fn)


def get_roi_data(data, mapping, ROI):
    """Cut the columns that belong to one ROI out of a 2-D prediction table.

    Args:
        data: array-like with two axes; it is copied into a new numpy array.
        mapping: dict from ROI name to a (start, stop) pair of column indices.
        ROI: name of the ROI to extract.

    Returns:
        numpy array holding every row of `data` and the columns start..stop-1.
    """
    table = np.array(data)
    bounds = mapping[ROI]
    roi_columns = slice(bounds[0], bounds[1])
    return table[:, roi_columns]


# results[ROI][sub] -> mean of all models' predictions for that ROI and subject
results = {}
for ROI in settings.ROIs:
    ROI_results = {}
    for sub in settings.subs:
        roi_sum = None
        for model_name in model_names:
            roi_pred = get_roi_data(predictions[model_name][sub]["fmri_data"],
                                    predictions[model_name][sub]["mapping"],
                                    ROI)
            roi_sum = roi_pred if roi_sum is None else roi_sum + roi_pred

        # Divide once, after all models were summed. Dividing inside the model
        # loop would shrink earlier models again on every iteration and give
        # them less weight than later ones.
        ROI_results[sub] = roi_sum / len(model_names)

    results[ROI] = ROI_results

# Submission files are first written locally as /home/<track>.pkl and /home/<track>.zip.
output_file = "/home/" + settings.track

data_handler.save_dict(results, output_file + ".pkl")
# The context manager closes (and thereby finalises) the archive even if
# write() raises; inside the archive the pickle is stored as "<track>.pkl".
with zipfile.ZipFile(output_file + ".zip", 'w') as zipped_results:
    zipped_results.write(output_file + ".pkl", settings.track + ".pkl")

In [None]:
!mv {output_file}.zip {settings.OUTPUT_FOLDER}