In [1]:
from torch.utils.data import Dataset
import os
import h5py
import numpy as np
from collections import defaultdict
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from vector_quantize_pytorch import VectorQuantize
from torchmetrics.functional import pearson_corrcoef
from pathlib import Path

from utils import load_fmri, align_features_and_fmri_samples

In [5]:
class AlgonautsDataset(Dataset):
    def __init__(self, features_dir, fmri_dir, movies, subject, excluded_samples_start=5, excluded_samples_end=5, hrf_delay=3, stimulus_window=5):
        self.features_dir = features_dir
        self.fmri_dir = fmri_dir
        self.movies = movies
        self.subject = subject
        self.excluded_samples_start = excluded_samples_start
        self.excluded_samples_end = excluded_samples_end
        self.hrf_delay = hrf_delay
        self.stimulus_window = stimulus_window
        self.partition_indices = defaultdict(list)
        
        # First load all raw features
        stimuli_features = {"visual": {}, "audio": {}, "language": {}}
        # Load audio and video features first
        for movie in self.movies:
            if 'friends' in movie:
                season = movie.split('-')[1]
                dir_list = sorted(os.listdir(self.features_dir + 'audio')) #List of all audio for each subset of dataset
                for episode in dir_list:
                    if f"{season}e" in episode and '_features_' in episode:
                        episode_base = episode.split('_features_')[0] # friends_s01e01 and so on....
                        
                        for modality in ['audio', 'visual']:
                            with h5py.File(os.path.join(self.features_dir, modality, f"{episode_base}_features_{modality}.h5"), 'r') as f:
                                try:
                                    stimuli_features[modality][episode_base.split('_')[1]] = f[episode_base.split('_')[1]][modality][:]
                                except:
                                    f.visit(lambda x: print(x))
                lang_dir_list = sorted(os.listdir(self.features_dir + 'language'))
                for episode in lang_dir_list:
                    if f"{season}e" in episode and '_features_' in episode:
                        episode_base = episode.split('_features_')[0]
                        
                        with h5py.File(os.path.join(self.features_dir, 'language', f"{episode_base}_features_language.h5"), 'r') as f:
                            try:
                                st_season_episode = episode_base.split('_')[1]
                                stimuli_features['language'][st_season_episode] = f[st_season_episode]['language_pooler_output'][:]
                            except:
                                f.visit(lambda x: print(x))
            else:
                movie_name = movie.replace('movie10-', '')
                partitions = sorted([f for f in os.listdir(self.features_dir + 'audio') if movie_name in f and '_features_' in f])
                
                for partition in partitions:
                    partition_base = partition.split('_features_')[0]
                    
                    for modality in ['audio', 'visual']:
                        with h5py.File(os.path.join(self.features_dir, modality, f"{partition_base}_features_{modality}.h5"), 'r') as f:
                            try:
                                stimuli_features[modality][partition_base] = f[partition_base][modality][:]
                            except:
                                f.visit(lambda x: print(x))
                lang_partitions = sorted([f for f in os.listdir(self.features_dir + 'language') if movie_name in f and '_features_' in f])
                
                for partition in lang_partitions:
                    partition_base = partition.split('_features_')[0]
                    
                    with h5py.File(os.path.join(self.features_dir, 'language', f"{partition_base}_features_language.h5"), 'r') as f:
                        try:
                            stimuli_features['language'][partition_base] = f[partition_base]['language_pooler_output'][:]
                        except:
                            f.visit(lambda x: print(x))

        fmri_data = load_fmri(self.fmri_dir, self.subject)

        self.aligned_features, self.aligned_fmri = align_features_and_fmri_samples(
            stimuli_features, 
            fmri_data, 
            self.excluded_samples_start, 
            self.excluded_samples_end, 
            self.hrf_delay, 
            self.stimulus_window, 
            self.movies
        )
        
    def __len__(self):
        return self.aligned_features['audio'].shape[0]

    def __getitem__(self, idx):
        return {
            'audio': self.aligned_features['audio'][idx],
            'video': self.aligned_features['visual'][idx],
            'language': self.aligned_features['language'][idx],
            'fmri': self.aligned_fmri[idx]
        }
        


In [6]:
features_dir = '/home/pranav/mihir/algonauts_challenge/AlgonautsDS-features/developer_kit/stimulus_features/raw/'
fmri_dir = '/home/pranav/mihir/algonauts_challenge/algonauts_2025.competitors/fmri/'
movies_train = ["friends-s01"]
# movies_train = ["friends-s01", "friends-s02", "friends-s03", "friends-s04", "friends-s05", "movie10-bourne", "movie10-figures", "movie10-life", "movie10-wolf"]
movies_val = ["friends-s06"]
modality = "all"  #@param ["visual", "audio", "language", "all"]

excluded_samples_start = 5  #@param {type:"slider", min:0, max:20, step:1}
excluded_samples_end = 5  #@param {type:"slider", min:0, max:20, step:1}
hrf_delay = 3  #@param {type:"slider", min:0, max:10, step:1}
stimulus_window = 5

subject = 1 #@param ["1", "2", "3", "5"] {type:"raw", allow-input: true}

train_ds = AlgonautsDataset(features_dir, fmri_dir, movies=movies_train, subject=subject, excluded_samples_start=excluded_samples_start, excluded_samples_end=excluded_samples_end, hrf_delay=hrf_delay, stimulus_window=stimulus_window)
val_ds = AlgonautsDataset(features_dir, fmri_dir, movies=movies_val, subject=subject, excluded_samples_start=excluded_samples_start, excluded_samples_end=excluded_samples_end, hrf_delay=hrf_delay, stimulus_window=stimulus_window)

stimuli_features loaded:  dict_keys(['visual', 'audio', 'language'])
fmri_data loaded
stimuli and fmri aligned
stimuli_features loaded:  dict_keys(['visual', 'audio', 'language'])
fmri_data loaded
stimuli and fmri aligned


In [7]:
print(train_ds[0]['audio'].shape)
print(train_ds[0]['video'].shape)
print(train_ds[0]['language'].shape)
print(train_ds[0]['fmri'].shape)
print(len(train_ds))
print(len(val_ds))


(5, 128)
(5, 8192)
(768,)
(1000,)
22787
22924


In [13]:
print(len(train_ds))
print(len(val_ds))

22787
22924


In [15]:
audio, video, language, fmri = train_ds[0]['audio'], train_ds[0]['video'], train_ds[0]['language'], train_ds[0]['fmri']
print(audio.shape)
print(video.shape)
print(language.shape)
print(fmri.shape)

torch.Size([1, 128])
torch.Size([250])
torch.Size([250])
torch.Size([1000])


In [None]:
batch_size = 4
num_workers = 4

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [4]:
print(train_ds.get_partition_indices())

defaultdict(<class 'list'>, {'friends-s01': (0, 22787)})


In [47]:
file = "/home/pranav/mihir/algonauts_challenge/algonauts_2025.competitors/fmri/sub-01/func/sub-01_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold.h5"
with h5py.File(file, 'r') as f:
    key_arr = list(f.keys())
    data = []
    data_keys = []
    for key in sorted(key_arr):
        if 's01' in key:
            data_keys.append(key)
            data.append(f[key][:])
    data = np.concatenate(data, axis=0)
    # print(f['sub-01_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold'].shape)
print(sorted(key_arr))

['ses-001_task-s01e02a', 'ses-001_task-s01e02b', 'ses-001_task-s01e03a', 'ses-001_task-s01e03b', 'ses-002_task-s01e04a', 'ses-002_task-s01e04b', 'ses-002_task-s01e05a', 'ses-002_task-s01e05b', 'ses-003_task-s01e01a', 'ses-003_task-s01e01b', 'ses-003_task-s01e06a', 'ses-003_task-s01e06b', 'ses-004_task-s01e07a', 'ses-004_task-s01e07b', 'ses-004_task-s01e08a', 'ses-004_task-s01e08b', 'ses-004_task-s01e09a', 'ses-004_task-s01e09b', 'ses-005_task-s01e10a', 'ses-005_task-s01e10b', 'ses-005_task-s01e11a', 'ses-005_task-s01e11b', 'ses-006_task-s01e12a', 'ses-006_task-s01e12b', 'ses-006_task-s01e13a', 'ses-006_task-s01e13b', 'ses-006_task-s01e14a', 'ses-006_task-s01e14b', 'ses-007_task-s01e15a', 'ses-007_task-s01e15b', 'ses-007_task-s01e16a', 'ses-007_task-s01e16b', 'ses-007_task-s01e17a', 'ses-007_task-s01e17b', 'ses-008_task-s01e18a', 'ses-008_task-s01e18b', 'ses-008_task-s01e19a', 'ses-008_task-s01e19b', 'ses-009_task-s01e20a', 'ses-009_task-s01e20b', 'ses-009_task-s01e21a', 'ses-009_task-s

In [None]:
data.shape

In [None]:
print(sorted(data_keys))

In [None]:
num = 0
for i in train_ds:
    # print(i['fmri'].shape)
    # print(i['audio'].shape)
    # print(i['video'].shape)
    # print(i['language'].shape)
    num += 1
print(num)
# print(i['fmri'].shape)

In [5]:


class ResidualBlock(nn.Module):
   def __init__(self, in_dim, out_dim):
       super().__init__()
       self.downsample = in_dim != out_dim
       self.net = nn.Sequential(
           nn.Linear(in_dim, out_dim),
           nn.LayerNorm(out_dim),
           nn.GELU(),
           nn.Linear(out_dim, out_dim),
           nn.LayerNorm(out_dim),
           nn.GELU()
       )
       if self.downsample:
           self.proj = nn.Linear(in_dim, out_dim)
   
   def forward(self, x):
       if self.downsample:
           return self.proj(x) + self.net(x)
       return x + self.net(x)

class Encoder(nn.Module):
   def __init__(self, input_dim=1000, hidden_dims=[512, 384, 256], num_tokens=32, codebook_dim=64):
       super().__init__()
       
       # Initial projection with one residual block
       self.input_proj = ResidualBlock(input_dim, hidden_dims[0])
       
       # Main network with one residual block per layer
       layers = []
       for i in range(len(hidden_dims)-1):
           layers.append(ResidualBlock(hidden_dims[i], hidden_dims[i+1]))
       self.layers = nn.Sequential(*layers)
       
       # Project to token space with one residual block
       self.token_proj = ResidualBlock(hidden_dims[-1], num_tokens * codebook_dim)
       
       self.num_tokens = num_tokens
       self.codebook_dim = codebook_dim
       
   def forward(self, x):
       x = self.input_proj(x)
       x = self.layers(x)
       x = self.token_proj(x)
       return x.view(x.shape[0], self.num_tokens, self.codebook_dim)

class Decoder(nn.Module):
   def __init__(self, output_dim=1000, hidden_dims=[256, 384, 512], num_tokens=32, codebook_dim=64):
       super().__init__()
       
       # Process tokens with one residual block
       self.token_proj = ResidualBlock(num_tokens * codebook_dim, hidden_dims[0])
       
       # Main network with one residual block per layer
       layers = []
       for i in range(len(hidden_dims)-1):
           layers.append(ResidualBlock(hidden_dims[i], hidden_dims[i+1]))
       self.layers = nn.Sequential(*layers)
       
       # Final projection with one residual block
       self.output_proj = ResidualBlock(hidden_dims[-1], output_dim)
       
   def forward(self, x):
       # x shape: [batch_size, num_tokens, codebook_dim] 
       x = x.reshape(x.shape[0], -1)  # Flatten tokens
       x = self.token_proj(x)
       x = self.layers(x)
       return self.output_proj(x)

class VQVAE(L.LightningModule):
    def __init__(
            self, 
            input_dim=1000, 
            hidden_dims=[512, 384, 256], 
            num_tokens=32, 
            codebook_size=1024, 
            codebook_dim=8,
            commitment_weight=0.25,
            quantizer_decay=0.99,
            learning_rate=3e-4,
            weight_decay=0.01
            ):
        super().__init__()
        
        self.encoder = Encoder(input_dim, hidden_dims, num_tokens, codebook_dim)
        self.decoder = Decoder(input_dim, hidden_dims[::-1], num_tokens, codebook_dim)
        self.quantizer = VectorQuantize(
                dim=codebook_dim,
                codebook_size=codebook_size,
                decay=quantizer_decay,
                commitment_weight=commitment_weight
                )
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.save_hyperparameters()
        
    def forward(self, x):
        z = self.encoder(x)
        z_q, indices, commitment_loss = self.quantizer(z)
        x_recon = self.decoder(z_q)
        return x_recon, commitment_loss, indices
    
    def encode(self, x):
        z = self.encoder(x)
        _, indices, _ = self.quantizer(z)
        return indices
        
    def decode(self, indices):
        z_q = self.quantizer.get_codes_from_indices(indices)
        return self.decoder(z_q)
    
class MultiModalMLP(L.LightningModule):
    def __init__(self, video_dim, audio_dim, text_dim, hidden_dim, num_tokens, fmri_tok_dir, learning_rate, weight_decay):
        super().__init__()
        
        # Dimensions for each modality
        self.video_dim = video_dim 
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        self.fmri_tokenizer = VQVAE.load_from_checkpoint(fmri_tok_dir)
        
        for param in self.fmri_tokenizer.parameters():
            param.requires_grad = False
        self.fmri_tokenizer.eval()
        
        # Trainable token for missing text
        self.missing_text_token = nn.Parameter(torch.randn(1, text_dim))
        
        # MLP layers
        total_dim = video_dim + audio_dim + text_dim
        self.mlp = nn.Sequential(
            nn.Linear(total_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, num_tokens)
        )
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.save_hyperparameters()

    def forward(self, video, audio, text):
        """
        Args:
            video: [batch_size, video_dim]
            audio: [batch_size, audio_dim] 
            text: [batch_size, text_dim]
        """
        batch_size = video.shape[0]
        # Generate text mask based on NaN values
        text_mask = ~torch.isnan(text).any(dim=1, keepdim=True)
        
        # Replace missing text with learned token
        text = torch.where(
            text_mask,
            text,
            self.missing_text_token.expand(batch_size, -1)
        )
        
        # Concatenate all features
        text = text.unsqueeze(1)
        x = torch.cat([video, audio, text], dim=-1).squeeze(1)
        # Predict tokens
        return self.mlp(x)

    def calculate_metrics(self, x, x_recon):
        # Flatten the tensors for correlation calculation
        x_flat = x.reshape(x.shape[0], -1)
        x_recon_flat = x_recon.reshape(x_recon.shape[0], -1)
        
        # Calculate Pearson R for each sample in batch
        correlations = torch.stack([
            pearson_corrcoef(x_flat[i], x_recon_flat[i])
            for i in range(x_flat.shape[0])
        ])
        avg_pearson_r = correlations.mean()
        
        # Calculate variance explained
        total_variance = torch.var(x_flat, dim=1).sum()
        residual_variance = torch.var(x_flat - x_recon_flat, dim=1).sum()
        variance_explained = 1 - (residual_variance / total_variance)
        
        # Calculate MSE and MAE
        mse = F.mse_loss(x_flat, x_recon_flat)
        mae = F.l1_loss(x_flat, x_recon_flat)
        
        return avg_pearson_r, variance_explained, mse, mae
        

    def training_step(self, batch):
        video, audio, text, fmri = batch['video'], batch['audio'], batch['language'], batch['fmri']
        logits = self(video, audio, text)
        fmri_tokens = self.fmri_tokenizer.encode(fmri).to(dtype=logits.dtype)

        criterion = nn.MSELoss()
        loss = criterion(logits, fmri_tokens)
        recon_fmri = self.fmri_tokenizer.decode(logits.round().long())

        avg_pearson_r, variance_explained, mse, mae = self.calculate_metrics(fmri, recon_fmri)
        self.log('train_loss', loss)
        self.log('train_pearson_r', avg_pearson_r)
        self.log('train_variance_explained', variance_explained)
        self.log('train_mse', mse)
        self.log('train_mae', mae)
        return loss
    
    def validation_step(self, batch):
        video, audio, text, fmri = batch['video'], batch['audio'], batch['language'], batch['fmri']
        logits = self(video, audio, text)
        fmri_tokens = self.fmri_tokenizer.encode(fmri).to(dtype=logits.dtype)
        
        criterion = nn.MSELoss()
        val_loss = criterion(logits, fmri_tokens)
        val_recon_fmri = self.fmri_tokenizer.decode(logits.round().long())
        val_avg_pearson_r, val_variance_explained, val_mse, val_mae = self.calculate_metrics(fmri, val_recon_fmri)
        
        self.log('val_loss', val_loss)
        self.log('val_pearson_r', val_avg_pearson_r)
        self.log('val_variance_explained', val_variance_explained)
        self.log('val_mse', val_mse)
        self.log('val_mae', val_mae)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), 
            lr=self.learning_rate, 
            betas=(0.9, 0.999), 
            eps=1e-8, 
            weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }

In [6]:
# Example usage:

batch_size = 4
video_dim = 8192
audio_dim = 128
text_dim = 768
hidden_dim = 1024
num_tokens = 32
learning_rate = 1e-4
weight_decay = 0.01 
num_workers = 4
fmri_tok_dir = "/home/pranav/mihir/algonauts_challenge/algonauts2025/checkpoints/fmri_vqvae_res_gelu_8qdim_1024/epoch=44_val_loss=0.137.ckpt"

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [7]:
print("Training samples: ", len(train_ds))
print("Validation samples: ", len(val_ds))

Training samples:  22787
Validation samples:  22924


In [8]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

run_name = "MM_MLP_1024dim_only_sub01"
wandb_logger = WandbLogger(
    project="algonauts2025",
    name=run_name,
    save_dir="wandb_logs/"
)
trainer = L.Trainer(
    max_epochs=50,
    precision=32,
    logger=wandb_logger,
    callbacks=[
        ModelCheckpoint(
            dirpath=f'checkpoints/{run_name}',
            filename='{epoch:02d}_{val_loss:.3f}',
            monitor='val_loss',
            mode='min',
            save_top_k=1,
            save_last=True
        )
    ]
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
from lightning.pytorch.utilities.model_summary import ModelSummary
torch.set_float32_matmul_precision('high')
mm_mlp = MultiModalMLP(
    video_dim=video_dim,
    audio_dim=audio_dim,
    text_dim=text_dim,
    hidden_dim=hidden_dim,
    num_tokens=num_tokens,
    fmri_tok_dir=fmri_tok_dir,
    learning_rate=learning_rate,
    weight_decay=weight_decay
)
summary = ModelSummary(mm_mlp, max_depth=1)
print(summary)

  | Name           | Type       | Params | Mode 
------------------------------------------------------
0 | fmri_tokenizer | VQVAE      | 5.4 M  | eval 
1 | mlp            | Sequential | 10.4 M | train
  | other params   | n/a        | 768    | n/a  
------------------------------------------------------
10.4 M    Trainable params
5.4 M     Non-trainable params
15.8 M    Total params
63.157    Total estimated model params size (MB)
8         Modules in train mode
79        Modules in eval mode


In [10]:
trainer.fit(mm_mlp, train_dl, val_dl)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmihir-neal[0m ([33mmihirneal[0m). Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type       | Params | Mode 
------------------------------------------------------
0 | fmri_tokenizer | VQVAE      | 5.4 M  | eval 
1 | mlp            | Sequential | 10.4 M | train
  | other params   | n/a        | 768    | n/a  
------------------------------------------------------
10.4 M    Trainable params
5.4 M     Non-trainable params
15.8 M    Total params
63.157    Total estimated model params size (MB)
8         Modules in train mode
79        Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
