In [None]:
import numpy as np
import pandas as pd 
from pathlib import Path
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F
from transformers import AutoImageProcessor, ViTMAEConfig, ViTMAEModel, ViTMAEForPreTraining
from torch.utils.data import DataLoader, Dataset

from engine_hms_model import CustomDataset, JobConfig, ModelConfig 
from engine_hms_trainer import load_kaggle_data, TARGETS, TARGETS_PRED, BRAIN_ACTIVITY, DEVICE

In [None]:
# Show the constants imported from engine_hms_trainer, one per line.
print(TARGETS, TARGETS_PRED, BRAIN_ACTIVITY, DEVICE, sep="\n")

In [None]:
# Load the data via the project helper; it returns four objects
# (presumably the easy/hard training splits plus spectrogram and EEG
# lookups — verify against engine_hms_trainer.load_kaggle_data).
train_easy, train_hard, all_specs, all_eegs = load_kaggle_data(
    JobConfig.PATHS,
    JobConfig.ENTROPY_SPLIT,
)

# Build a training-mode dataset on top of the `train_easy` split.
my_dataset = CustomDataset(
    train_easy,
    TARGETS,
    ModelConfig,
    all_specs,
    all_eegs,
    mode="train",
)

# Sanity check: pull the first sample and show input / target shapes.
X, y = my_dataset[0]
print(X.shape, y.shape, sep="\n")

In [None]:
class CustomMAE(nn.Module):
    """Pretrained ViT-MAE encoder followed by a small MLP classification head.

    The input tensor is tiled into a single-plane image, replicated to three
    channels, resized to 224x224, run through the checkpoint's image
    processor and the ViT-MAE encoder, and the first output token is fed to
    the MLP head.

    Args:
        backbone: Hugging Face model id or local path of the ViT-MAE checkpoint.
        num_classes: Number of output logits.
        mlp_hidden_size: Width of the hidden layer in the MLP head.
        mae_dropout: Value written to ``hidden_dropout_prob`` of the encoder config.
        mae_attention_dropout: Value written to ``attention_probs_dropout_prob``
            of the encoder config.
    """

    def __init__(self, backbone="", num_classes=6, mlp_hidden_size=512, mae_dropout=0.05, mae_attention_dropout=0.05):
        super().__init__()

        # Start from the checkpoint's own config so every architecture field
        # matches the pretrained weights, then override only the dropout rates.
        mae_config = ViTMAEConfig.from_pretrained(backbone)
        mae_config.hidden_dropout_prob = mae_dropout
        mae_config.attention_probs_dropout_prob = mae_attention_dropout

        self.pre_processor = AutoImageProcessor.from_pretrained(backbone)
        # `from_pretrained` is a classmethod: calling it on a freshly built
        # instance discards that instance's config, so the config must be
        # passed explicitly for the dropout overrides to take effect.
        # NOTE(review): ViT-MAE masks a random subset of patches according to
        # `config.mask_ratio` on every forward pass — confirm that is wanted
        # for classification, or set `mae_config.mask_ratio` accordingly.
        self.vitmae = ViTMAEModel.from_pretrained(backbone, config=mae_config)
        self.mlp_head = nn.Sequential(
            nn.Linear(self.vitmae.config.hidden_size, mlp_hidden_size),
            nn.GELU(),
            nn.Linear(mlp_hidden_size, num_classes)
        )

    def __reshape_input(self, x):
        """Tile the channels of ``x`` into one image and resize it to 224x224.

        Channels 0-3 are stacked along the height axis, the remaining channels
        are stacked the same way, and the two stacks are placed side by side
        along the width axis. The result is repeated three times along the
        channel axis and bilinearly resized.

        Assumes ``x`` is (batch, 8, H, W) so that each stack has exactly one
        channel and the output has three — TODO confirm against CustomDataset.
        """
        # Split the channel axis into the first four channels and the rest,
        # and lay each group out vertically (dim=2 is height).
        concat_1 = torch.cat(torch.chunk(x[:, :4, :, :], 4, dim=1), dim=2)
        concat_2 = torch.cat(torch.chunk(x[:, 4:, :, :], 4, dim=1), dim=2)
        # Put the two vertical strips next to each other (dim=3 is width).
        concatenated = torch.cat((concat_1, concat_2), dim=3)
        # Replicate along the channel axis, then resize to the encoder input size.
        stacked = concatenated.repeat(1, 3, 1, 1)
        resized = F.interpolate(stacked, size=(224, 224), mode='bilinear', align_corners=False)
        return resized

    def forward(self, x):
        """Return classification logits of shape (batch, num_classes) for ``x``.

        The encoder inputs are moved to ``x.device``, so the model is expected
        to live on the same device as its input.
        """
        x = self.__reshape_input(x)
        # The image processor operates on CPU data and returns CPU tensors, so
        # hand it a detached CPU copy and move its outputs back to the input's
        # device; otherwise the encoder receives CPU tensors when on GPU.
        # NOTE(review): the processor's own rescale/normalise settings are
        # applied to these float inputs — confirm they suit the data range.
        input_data = self.pre_processor(images=x.detach().cpu(), return_tensors="pt", padding=True)
        input_data = {key: value.to(x.device) for key, value in input_data.items()}
        outputs = self.vitmae(**input_data)
        last_hidden_state = outputs.last_hidden_state
        # Classify from the first token of the encoder output (presumably CLS).
        logits = self.mlp_head(last_hidden_state[:, 0])
        return logits


In [None]:
# Pass the configured backbone identifier itself, not the literal string
# 'ModelConfig.BACKBONE', which from_pretrained would treat as a checkpoint name.
# NOTE(review): assumes ModelConfig.BACKBONE names a ViT-MAE checkpoint — confirm.
model = CustomMAE(backbone=ModelConfig.BACKBONE, num_classes=6)