In [5]:
import torch
import yaml
from models.models import EEG_ResNet

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Locations of the pretrained config and checkpoint (relative to the working dir).
config_path = 'pretrained/5s/config_xy.yaml'
weights_path = 'pretrained/5s/model_0_checkpoint.pt'

# yaml.safe_load only constructs plain Python objects, never arbitrary classes.
with open(config_path) as file:
    config = yaml.safe_load(file)

# Every encoder hyper-parameter is read from the "model" section of the config;
# the projection size lives in its nested "ELM" sub-section.
mp = config["model"]
encoder = EEG_ResNet(
    in_channels=mp["in_channels"],
    conv1_params=mp["encoder_conv1_params"],
    n_blocks=mp["encoder_blocks"],
    res_params=mp["encoder_res_params"],
    res_pool_size=mp["encoder_pool_size"],
    dropout_p=mp["encoder_dropout_p"],
    res_dropout_p=mp["res_dropout_p"],
    proj_size=mp["ELM"]["eeg_proj_size"],
)
encoder = encoder.to(device)

# Restore the pretrained weights into the bare (unwrapped) encoder.
DDP = config["training"]["DDP"]
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code; only load trusted files (newer PyTorch versions accept weights_only=True).
state_dict = torch.load(weights_path, map_location=device)

if DDP:
    # DistributedDataParallel keeps the real model under a `.module` attribute,
    # so a checkpoint saved from the DDP wrapper has every key prefixed with
    # "module.". `encoder` is a plain EEG_ResNet (not wrapped), so that prefix
    # must be *stripped*; prepending it would yield keys the bare model can
    # never match. Keys without the prefix (e.g. a checkpoint saved via
    # `ddp_model.module.state_dict()`) are left untouched.
    ddp_prefix = "module."
    state_dict = {
        (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
        for key, value in state_dict.items()
    }
encoder.load_state_dict(state_dict)

# Inference mode for the module: disables training-only layers such as the
# dropout configured via encoder_dropout_p / res_dropout_p.
encoder.eval()

# Smoke test: feed a small batch of Gaussian noise shaped like the encoder's
# expected input and report the shapes of the two tensors it returns.
batch_size = 4
n_channels, n_time_samples = mp["in_channels"], mp["n_time_samples"]
synth_data = torch.randn((batch_size, n_channels, n_time_samples), device=device)

# No gradients are needed for a forward-only check.
with torch.no_grad():
    emb, proj_emb = encoder(synth_data)

print(f"Representation shape from encoder: {emb.shape}")
print(f"Projected representation shape from encoder: {proj_emb.shape}")


Representation shape from encoder: torch.Size([4, 96])
Projected representation shape from encoder: torch.Size([4, 256])
