# Visible Neural Network

In [None]:
# Data ----
from dataG2F.core import get_data
from dataG2F.qol  import ensure_dir_path_exists

# Data Utilities ----
import numpy  as np
import pandas as pd

# Model Building  ----
## General ====
import torch
from   torch import nn
import torch.nn.functional as F
from   torch.utils.data import Dataset
from   torch.utils.data import DataLoader

from vnnpaper.zma import \
    BigDataset,    \
    plDNN_general, \
    mask_parents,  \
    vnn_factory_1, \
    vnn_factory_2, \
    vnn_factory_3

# Hyperparameter Tuning ----
import os # needed for checking history (saved by lightning) 

## Logging with Pytorch Lightning ====
import lightning.pytorch as pl
from   lightning.pytorch.loggers import CSVLogger # used to save the history of each trial (used by ax)

## Adaptive Experimentation Platform ====
from ax.service.ax_client import AxClient, ObjectiveProperties
# from ax.utils.notebook.plotting import init_notebook_plotting, render

In [None]:
# Allow float32 matrix multiplications to use faster, lower-precision internal kernels
# (trades a little numerical precision for training speed).
torch.set_float32_matmul_precision(precision='medium')

In [None]:
# init_notebook_plotting()

## Setup

In [None]:
# Root directory for the artifacts this notebook writes; note the trailing slash.
cache_path = "../nbs_artifacts/zma_g2f_multitarget_eres_0hyps/"

In [None]:
## Settings ====
# Hyperparameter-search budget
run_hyps       = 32     # trials to run in this session
run_hyps_force = False  # run more trials even if the target number (max_hyps) has been reached?
max_hyps       = 64     # target number of trials for the whole experiment

# Run settings
params_run = dict(
    batch_size = 256,
    max_epoch  = 32,
)

# Data settings
params_data = {
    # Target traits (single-target alternative: 'Yield_Mg_ha').
    # Descriptions quoted from the competition data readme.
    'y_var': [
        'Yield_Mg_ha',      # Grain yield in Mg per ha at 15.5% grain moisture, using plot area without alley (Mg/ha).
        'Pollen_DAP_days',  # Number of days after planting that 50% of plants in the plot began shedding pollen.
        'Silk_DAP_days',    # Number of days after planting that 50% of plants in the plot had visible silks.
        'Plant_Height_cm',  # Measured as the distance between the base of a plant and the ligule of the flag leaf (centimeter).
        'Ear_Height_cm',    # Measured as the distance from the ground to the primary ear bearing node (centimeter).
        'Grain_Moisture',   # Water content in grain at harvest (percentage).
        'Twt_kg_m3',        # Shelled grain test weight (kg/m3), a measure of grain density.
    ],

    'y_resid':       'Env',         # None, Env, Geno
    'y_resid_strat': 'naive_mean',  # None, naive_mean, filter_mean, ...

    # Parents whose hybrids form the holdout set, grouped by year.
    # Commented-out names are other candidates that are currently not held out.
    'holdout_parents': [
        'LH244',    # 2022
        'PHZ51',    # 2021 (other candidates: 'PHP02', 'PHK76')
        'LH195',    # 2019 (other candidate: 'PHT69')
        # 2017: 'PHW52', 'PHN82'
        # 2016: 'DK3IIH6'
        # 2015: 'PHB47', 'LH82'
        # 2014: 'LH198', 'LH185', 'PB80', 'CG102'
    ],
}

In [None]:
# Ax search space for the VNN hyperparameters. Each dict uses Ax's parameter
# representation: 'range' (bounds), 'choice' (values) or 'fixed' (value).
# `is_ordered` / `sort_values` only apply to 'choice' parameters, so they are not set on 'range' ones.
params_list = [
    ## Output Size ====
    {
    'name': 'default_out_nodes_inp',
    'type': 'range',
    'bounds': [1, 8],
    'value_type': 'int',
    'log_scale': False
    },
    {
    'name': 'default_out_nodes_edge',
    'type': 'range',
    'bounds': [1, 32],
    'value_type': 'int',
    'log_scale': False
    },
    {
    # one output node per target trait (a single, non-list y_var gives one output)
    'name': 'default_out_nodes_out',
    'type': 'fixed',
    'value': len(params_data['y_var']) if isinstance(params_data['y_var'], list) else 1,
    'value_type': 'int',
    'log_scale': False
    },
    ## Dropout ====
    {
    'name': 'default_drop_nodes_inp',
    'type': 'range',
    'bounds': [0.01, 0.99],
    'value_type': 'float',
    'log_scale': False
    },
    {
    'name': 'default_drop_nodes_edge',
    'type': 'range',
    'bounds': [0.01, 0.99],
    'value_type': 'float',
    'log_scale': False
    },
    {
    'name': 'default_drop_nodes_out',
    'type': 'range',
    'bounds': [0.01, 0.99],
    'value_type': 'float',
    'log_scale': False
    },
    ## Node Repeats ====
    {
    'name': 'default_reps_nodes_inp',
    'type': 'choice',
    'values': [1, 2, 3],
    'value_type': 'int',
    'is_ordered': True,
    'sort_values':True
    },
    {
    'name': 'default_reps_nodes_edge',
    'type': 'choice',
    'values': [1, 2, 3],
    'value_type': 'int',
    'is_ordered': True,
    'sort_values':True
    },
    {
    'name': 'default_reps_nodes_out',
    'type': 'choice',
    'values': [1, 2, 3],
    'value_type': 'int',
    'is_ordered': True,
    'sort_values':True
    },
    ## Node Output Size Scaling ====
    {
    'name': 'default_decay_rate',
    'type': 'choice',
    # 0.0, 0.1, ..., 0.9 followed by 1.0, 2.0, ..., 11.0. Rounding keeps the small values
    # exact decimals (0.1*3 alone evaluates to 0.30000000000000004).
    'values': [round(0.1 * i, 1) for i in range(10)] + [float(i) for i in range(1, 12)],
    'value_type': 'float',
    'is_ordered': True,
    'sort_values':True
    }
    ]

In [None]:
# Lightning CSV logs go under <cache_path>/lightning; the experiment is named after the
# last directory of cache_path. os.path handles cache_path with or without a trailing slash.
lightning_log_dir = os.path.join(cache_path, "lightning")
exp_name = os.path.basename(os.path.normpath(cache_path))

In [None]:
# Placeholder parameterization used to build the initial VNN in the setup cells below.
# During tuning, each Ax trial passes its own parameterization to `evaluate` instead of
# modifying this dict. Defining it here also guarantees other params dicts can be merged into it.
params = {
'default_out_nodes_inp'  : 1,
'default_out_nodes_edge' : 1,
# one output node per target trait (a single, non-list y_var gives one output)
'default_out_nodes_out'  : len(params_data['y_var']) if isinstance(params_data['y_var'], list) else 1,

'default_drop_nodes_inp' : 0.0,
'default_drop_nodes_edge': 0.0,
'default_drop_nodes_out' : 0.0,

'default_reps_nodes_inp' : 1,
'default_reps_nodes_edge': 1,
'default_reps_nodes_out' : 1,

'default_decay_rate'     : 1
}

# Expose each setting as a module-level variable of the same name.
default_out_nodes_inp  = params['default_out_nodes_inp' ]
default_out_nodes_edge = params['default_out_nodes_edge']
default_out_nodes_out  = params['default_out_nodes_out' ]

default_drop_nodes_inp = params['default_drop_nodes_inp' ]
default_drop_nodes_edge= params['default_drop_nodes_edge']
default_drop_nodes_out = params['default_drop_nodes_out' ]

default_reps_nodes_inp = params['default_reps_nodes_inp' ]
default_reps_nodes_edge= params['default_reps_nodes_edge']
default_reps_nodes_out = params['default_reps_nodes_out' ]

default_decay_rate = params['default_decay_rate' ]

In [None]:
batch_size = params_run['batch_size']
max_epoch  = params_run['max_epoch']

# Always work with a list of target columns. A single-trait setting (a bare string) would
# otherwise be iterated character by character in the per-trait loops below.
y_var = params_data['y_var']
if not isinstance(y_var, list):
    y_var = [y_var]

In [None]:
# File-name prefix for saved artifacts: the experiment directory name, plus the
# residualization strategy when one is configured.
save_prefix = [e for e in cache_path.split('/') if e != ''][-1]

# Treat both the Python value None and the string 'None' as "no strategy".
# The settings list None as an option, and `'_' + None` would raise TypeError.
if params_data['y_resid_strat'] not in (None, 'None'):
    save_prefix = save_prefix+'_'+params_data['y_resid_strat']

ensure_dir_path_exists(dir_path = cache_path)

In [None]:
use_gpu_num = 0

# Pick the compute device. Only pin a specific GPU when CUDA is actually available;
# calling torch.cuda.set_device on a CPU-only machine fails.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda" and use_gpu_num in [0, 1]:
    torch.cuda.set_device(use_gpu_num)
print(f"Using {device} device")

Using cuda device


## Load Data

In [None]:
# Data Prep ----
# Fetch the cached inputs, one get_data call per key, in this order.
obs_geno_lookup, phno, ACGT_gene_slice_list, parsed_kegg_gene_entries = [
    get_data(key)
    for key in ('obs_geno_lookup', 'phno', 'KEGG_slices', 'KEGG_entries')
]

In [None]:
# Keep only observations where every target trait is present, then renumber rows 0..n-1.
# The position-based indices built in the following cells rely on that numbering.
has_all_targets = phno[y_var].notna().all(axis=1).to_numpy()

phno = phno.loc[has_all_targets, ].copy()
# drop=True discards the old index outright (reset_index().drop(columns='index') would
# instead remove any pre-existing column that happens to be named 'index').
phno = phno.reset_index(drop=True)

In [None]:
# Rebuild obs_geno_lookup for the filtered phenotype table.
# Columns: [Phno_Idx_new, Geno_Idx, Phno_Idx_Orig_new], where Phno_Idx_Orig_new is the
# row of the first phenotype observation that has the same Geno_Idx.
phno_geno = (phno
             .reset_index()
             .rename(columns={'index': 'Phno_Idx_new'})
             .loc[:, ['Phno_Idx_new', 'Geno_Idx']])

first_row_per_geno = (phno_geno
                      .drop(columns='Phno_Idx_new')
                      .drop_duplicates()
                      .reset_index()
                      .rename(columns={'index': 'Phno_Idx_Orig_new'}))

obs_geno_lookup = (pd.merge(phno_geno, first_row_per_geno)
                   .sort_values('Phno_Idx_new')
                   .reset_index(drop=True)
                   .to_numpy())

In [None]:
# make holdout sets
holdout_parents = params_data['holdout_parents']

# Parent-genotype mask from mask_parents (presumably one column per holdout parent,
# true where that parent appears in the hybrid; confirm against vnnpaper.zma).
mask = mask_parents(df= phno, col_name= 'Hybrid', holdout_parents= holdout_parents)

# Rows that hit no holdout parent train; rows that hit at least one are held out.
n_hits = mask.sum(axis=1)
train_mask = n_hits == 0
test_mask  = n_hits > 0

train_idx, test_idx = (m.loc[m].index for m in (train_mask, test_mask))

In [None]:
# Convert y to residuals (observed value minus an environment mean) if requested.
# Means are mapped back onto the rows of `phno` by Env_Idx instead of merged in. A merge
# re-creates the index, and an inner join can drop or regroup rows. Either would break the
# positional alignment with train_idx / test_idx / obs_geno_lookup built above.

target_cols   = y_var if isinstance(y_var, list) else [y_var]
y_resid       = params_data['y_resid']
y_resid_strat = params_data['y_resid_strat']
env_means     = None  # DataFrame indexed by Env_Idx, one column per target, when residualizing

if y_resid in (None, 'None'):
    pass
elif y_resid != 'Env':
    # Both strategies below average within Env_Idx; residualizing on anything else is not implemented.
    raise NotImplementedError(f"y_resid={y_resid!r} is not implemented; use None or 'Env'.")
elif y_resid_strat in (None, 'None'):
    pass
elif y_resid_strat == 'naive_mean':
    # use only data in the training set (especially since testers will be more likely to be found across envs)
    env_means = phno.loc[train_idx, ].groupby('Env_Idx')[target_cols].mean()
elif y_resid_strat == 'filter_mean':
    # for adjusting to environment we could use _all_ observations but ideally we will use the same set of genotypes across all observations
    def minimum_hybrids_for_env(tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']],
                                year = 2014):
        """Return the (Env, Hybrid) pairs of `year` that pass a count threshold.

        The threshold is the largest per-(Env, Hybrid) row count that still leaves every
        environment of that year with at least one hybrid.
        """
        # Within each year what hybrids are most common?
        # After groupby().count() the 'Year' column holds the number of rows per (Env, Hybrid).
        # NOTE(review): the comment below speaks of "the number of sites a hybrid is planted at",
        # but this count is per (Env, Hybrid) pair (rows within one environment); confirm intent.
        tmp = tmp.loc[(tmp.Year == year), ].groupby(['Env', 'Hybrid']).count().reset_index().sort_values('Year')

        all_envs = set(tmp.Env)
        # if we filter on the number of sites a hybrid is planted at, what is the largest number of sites we can ask for before we lose a location?
        # site counts for sets which contain all envs
        i = max([i for i in list(set(tmp.Year)) if len(set(tmp.loc[(tmp.Year >= i), 'Env'])) == len(all_envs)])

        before = len(set(tmp.loc[:, 'Hybrid']))
        after  = len(set(tmp.loc[(tmp.Year >= i), 'Hybrid']))
        print(f'Reducing {year} hybrids from {before} to {after} ({round(100*after/before)}%).')
        tmp = tmp.loc[(tmp.Year >= i), ['Env', 'Hybrid']].reset_index(drop=True)
        return tmp

    # The retained hybrids do not depend on the trait, so compute them once for all traits.
    filter_hybrids = [minimum_hybrids_for_env(tmp = phno.loc[:, ['Env', 'Year', 'Hybrid']], year = year)
                      for year in list(set(phno.Year))]
    # Environment means over all observations of the retained hybrids (not only training rows).
    env_means = (pd.concat(filter_hybrids)
                 .merge(phno, how = 'left')
                 .groupby('Env_Idx')[target_cols]
                 .mean())
else:
    raise ValueError(f"Unknown y_resid_strat={y_resid_strat!r}; expected None, 'naive_mean' or 'filter_mean'.")

if env_means is not None:
    missing_envs = sorted(set(phno['Env_Idx']) - set(env_means.index))
    if missing_envs:
        # e.g. with 'naive_mean', an environment containing only holdout observations
        raise ValueError(f"No environment mean available for Env_Idx {missing_envs}; "
                         "cannot residualize these rows without dropping them.")
    # Subtract each row's environment mean; Series.map keeps phno's rows and index unchanged.
    phno = phno.assign(**{col: phno[col] - phno['Env_Idx'].map(env_means[col])
                          for col in target_cols})
            

In [None]:
# Center and scale the targets so the training rows have mean 0 and sd 1 in each column.
# The statistics come from the training rows only, to prevent information leakage.
assert not phno.loc[:, y_var].isna().to_numpy().any()  # no missing targets in any y column

y = phno.loc[:, y_var].to_numpy()  # 2-D (one column per target) when y_var is a list
y_train = y[train_idx]
y_c = y_train.mean(axis=0)
y_s = y_train.std(axis=0)

y = (y - y_c) / y_s

## Fit Using VNNHelper

In [None]:
# Build the VNN helper from the placeholder `params`; the second result is the lookup
# passed on as node_to_inp_num_dict in the next step.
myvnn, new_lookup_dict = vnn_factory_1(
    params                   = params,
    parsed_kegg_gene_entries = parsed_kegg_gene_entries,
    ACGT_gene_slice_list     = ACGT_gene_slice_list,
)

################################################################################
{'default_out_nodes_inp': 1, 'default_out_nodes_edge': 1, 'default_out_nodes_out': 7, 'default_drop_nodes_inp': 0.0, 'default_drop_nodes_edge': 0.0, 'default_drop_nodes_out': 0.0, 'default_reps_nodes_inp': 1, 'default_reps_nodes_edge': 1, 'default_reps_nodes_out': 1, 'default_decay_rate': 1}
################################################################################
Retaining 43.53%, 6067/13939 Entries
Removed node "Others"


### Calculate each node's membership in each matrix and its position within it

In [None]:
### Creating Structured Matrices for Layers
# M_list presumably holds one structured-matrix description per layer;
# M_list[0].row_inp is used below to select the genotype inputs.
M_list = vnn_factory_2(
    node_to_inp_num_dict = new_lookup_dict,
    vnn_helper           = myvnn,
)

### Set Up Dataloaders Using `M_list`

In [None]:
lookup_dict = new_lookup_dict

# Genotype input tensor: for each input node of the first structured matrix
# (M_list[0].row_inp, in that order), flatten its KEGG slice to (rows, -1) and
# concatenate the results column-wise.
vals = get_data('KEGG_slices')
# Convert only the slices that are actually used.
vals = [torch.from_numpy(vals[lookup_dict[i]]).to(torch.float)
        for i in M_list[0].row_inp]
# Use each slice's own leading dimension instead of a hard-coded row count (previously 4926).
vals = torch.concat([v.reshape(v.shape[0], -1) for v in vals], dim=1)

vals = vals.to(device)

In [None]:
def make_dataloader(obs_idx, shuffle):
    """Build a DataLoader over the phenotype rows listed in `obs_idx`.

    Every loader shares the same genotype lookup (`obs_geno_lookup`), centered/scaled
    targets (`y`) and genotype tensor (`vals`); only the row subset and shuffling differ.

    Parameters
    ----------
    obs_idx : index-like of int
        Row positions of `phno` / `y` to include.
    shuffle : bool
        Whether the DataLoader reshuffles each epoch.
    """
    return DataLoader(BigDataset(
        lookups_are_filtered = False,
        lookup_obs  = torch.from_numpy(np.array(obs_idx)),
        lookup_geno = torch.from_numpy(obs_geno_lookup),
        # add then squeeze a singleton axis: (N, k) stays (N, k) for k > 1, (N, 1) becomes (N,)
        y =           torch.from_numpy(y).to(torch.float32)[:, None].squeeze(),
        G =           vals,
        G_type = 'raw',
        send_batch_to_gpu = 'cuda:0'
        ),
        batch_size = batch_size,
        shuffle = shuffle
    )

training_dataloader   = make_dataloader(train_idx, shuffle = True)
# The held-out (test) rows double as Lightning's validation set.
validation_dataloader = make_dataloader(test_idx,  shuffle = False)

## Structured Layer

In [None]:
class NeuralNetwork(nn.Module):
    """Container that applies its layers one after another.

    Parameters
    ----------
    layer_list : iterable of nn.Module
        Layers in application order. They are wrapped in an nn.ModuleList so that
        their parameters are registered with this module.
    """

    def __init__(self, layer_list):
        super().__init__()
        self.layer_list = nn.ModuleList(layer_list)

    def forward(self, x):
        """Feed `x` through every layer in order and return the last layer's output."""
        out = x
        for layer in self.layer_list:
            out = layer(out)
        return out

## Tiny Test Study

In [None]:
def evaluate(parameterization):
    """Train one VNN for an Ax trial and return its objective.

    The network is built from `parameterization` with the vnn_factory_* helpers,
    trained for `max_epoch` epochs with Lightning (CSV-logged under
    `lightning_log_dir`/`exp_name`), and scored by the lowest per-epoch mean
    training loss.

    Parameters
    ----------
    parameterization : dict
        Hyperparameter values suggested by Ax (names as in `params_list`).

    Returns
    -------
    dict
        {"train_loss": (value, 0.0)}, the raw_data format passed to
        AxClient.complete_trial (the second element is reported as 0.0).
    """
    myvnn, new_lookup_dict = vnn_factory_1(parsed_kegg_gene_entries = parsed_kegg_gene_entries, params = parameterization, ACGT_gene_slice_list = ACGT_gene_slice_list)
    M_list = vnn_factory_2(vnn_helper = myvnn, node_to_inp_num_dict = new_lookup_dict)
    layer_list =  vnn_factory_3(M_list = M_list)
    model = NeuralNetwork(layer_list = layer_list)

    VNN = plDNN_general(model)
    logger = CSVLogger(lightning_log_dir, name=exp_name)
    logger.log_hyperparams(params={
        'params': parameterization
    })

    trainer = pl.Trainer(max_epochs=max_epoch, logger=logger)
    trainer.fit(model=VNN, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)

    # if we were optimizing number of training epochs this would be an effective loss to use.
    # trainer.callback_metrics['train_loss']

    # To potentially _overtrain_ models and still let the selection be based on their best possible
    # performance, use the lowest average error in an epoch.
    # `logger.log_dir` is the version_<n> directory this trial's logger wrote to. That avoids guessing
    # the newest version from a directory listing, which breaks on stray files or concurrent runs.
    M = pd.read_csv(os.path.join(logger.log_dir, "metrics.csv"))
    M = M.loc[:, ['epoch', 'train_loss']].dropna()

    epoch_mean_loss = M.groupby('epoch')['train_loss'].mean()

    # plain float so Ax stores a native number
    train_metric = float(epoch_mean_loss.min())
    print(train_metric)
    return {"train_loss": (train_metric, 0.0)}


In [None]:
## Generated variables ====
json_path = os.path.join(lightning_log_dir, f"{exp_name}.json")

# Resume a previous search if its saved state exists; otherwise start a new experiment.
loaded_json = os.path.exists(json_path)
if loaded_json:
    ax_client = AxClient.load_from_json_file(filepath = json_path)
else:
    ax_client = AxClient()
    ax_client.create_experiment(
        name=exp_name,
        parameters=params_list,
        objectives={"train_loss": ObjectiveProperties(minimize=True)}
    )

# Number of trials to run now: `run_hyps`, capped so the experiment does not grow past
# `max_hyps` total trials, unless `run_hyps_force` is set.
n_existing_trials = len(ax_client.experiment.trials)
if run_hyps_force:
    n_new_trials = run_hyps
else:
    n_new_trials = max(0, min(run_hyps, max_hyps - n_existing_trials))
run_trials_bool = n_new_trials > 0

for _ in range(n_new_trials):
    parameterization, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameterization))
    # Save after every trial so an interruption loses at most the trial in progress.
    ax_client.save_to_json_file(filepath = json_path)

  return _class(
  return _class(
  return _class(
  return _class(
[INFO 05-29 15:33:54] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 05-29 15:33:54] ax.modelbridge.generation_strategy: Note that parameter values in dataframe are rounded to 2 decimal points; the values in the dataframe are thus not the exact ones suggested by Ax in trials.
  return cls(df=pd.concat(dfs, axis=0, sort=True))
[INFO 05-29 15:33:54] ax.service.ax_client: Generated new trial 6 with parameters {'default_out_nodes_inp': 8, 'default_out_nodes_edge': 31, 'default_drop_nodes_inp': 0.588232, 'default_drop_nodes_edge': 0.399651, 'default_drop_nodes_out': 0.143371, 'default_reps_nodes_inp': 2, 'default_reps_nodes_edge': 1, 'default_reps_nodes_out': 3, 'default_decay_rate': 2.0, 'default_out_nodes_out': 7} using model Sobol.


################################################################################
{'default_out_nodes_inp': 8, 'default_out_nodes_edge': 31, 'default_drop_nodes_inp': 0.588231707829982, 'default_drop_nodes_edge': 0.39965105094015596, 'default_drop_nodes_out': 0.14337116550654175, 'default_reps_nodes_inp': 2, 'default_reps_nodes_edge': 1, 'default_reps_nodes_out': 3, 'default_decay_rate': 2.0, 'default_out_nodes_out': 7}
################################################################################
Retaining 43.53%, 6067/13939 Entries
Removed node "Others"


/home/kickd/miniconda3/envs/fastai/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/kickd/miniconda3/envs/fastai/lib/python3.11/si ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type          | Params
---------------------------------------
0 | mod  | NeuralNetwork | 989 M 
---------------------------------------
295 K     Trainable params
989 M     Non-trainable params
989 M     Total params
3,959.301 Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/kickd/miniconda3/envs/fastai/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
  loss = F.mse_loss(pred, y_i)
/home/kickd/miniconda3/envs/fastai/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

  loss = F.mse_loss(pred, y_i)
  loss = F.mse_loss(pred, y_i)


Validation: |          | 0/? [00:00<?, ?it/s]

  loss = F.mse_loss(pred, y_i)


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.
[INFO 05-29 15:34:59] ax.service.ax_client: Completed trial 6 with data: {'train_loss': (0.975791, 0.0)}.
[INFO 05-29 15:34:59] ax.service.ax_client: Saved JSON-serialized state of optimization to `./../nbs_artifacts/zma_g2f_yhat_02/lightning/zma_g2f_yhat_02.json`.


0.9757907986640929


In [None]:
# parameterization = ax_client.get_best_parameters()[0]

# myvnn, new_lookup_dict = vnn_factory_1(parsed_kegg_gene_entries = parsed_kegg_gene_entries, params = parameterization, ACGT_gene_slice_list = ACGT_gene_slice_list)
# M_list = vnn_factory_2(vnn_helper = myvnn, node_to_inp_num_dict = new_lookup_dict)
# layer_list =  vnn_factory_3(M_list = M_list)
# model = NeuralNetwork(layer_list = layer_list)

# VNN = plDNN_general(model)  
# optimizer = VNN.configure_optimizers()
# # logger = CSVLogger(lightning_log_dir, name=exp_name)
# # logger.log_hyperparams(params={
# #     'params': parameterization
# # })

# # trainer = pl.Trainer(max_epochs=max_epoch, logger=logger)
# trainer = pl.Trainer(max_epochs=max_epoch)
# trainer.fit(model=VNN, train_dataloaders=training_dataloader, val_dataloaders=validation_dataloader)

# model = VNN.mod
# model.to('cuda')(next(iter(training_dataloader))[1]).shape

  return cls(df=pd.concat(dfs, axis=0, sort=True))


################################################################################
{'default_out_nodes_inp': 2, 'default_out_nodes_edge': 26, 'default_drop_nodes_inp': 0.780218251850456, 'default_drop_nodes_edge': 0.6004913296550513, 'default_drop_nodes_out': 0.5098963896185158, 'default_reps_nodes_inp': 3, 'default_reps_nodes_edge': 3, 'default_reps_nodes_out': 2, 'default_decay_rate': 0.9, 'default_out_nodes_out': 7}
################################################################################
Retaining 43.53%, 6067/13939 Entries
Removed node "Others"


/home/kickd/miniconda3/envs/fastai/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/kickd/miniconda3/envs/fastai/lib/python3.11/si ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type          | Params
---------------------------------------
0 | mod  | NeuralNetwork | 1.4 B 
---------------------------------------
727 K     Trainable params
1.4 B     Non-trainable params
1.4 B     Total params
5,708.917 Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/kickd/miniconda3/envs/fastai/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/kickd/miniconda3/envs/fastai/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


torch.Size([256, 7])