# Training Notebook


- Load in training set
- Load in config
<!-- - Plot a protein -->
- Train
  - Record loss over time
  - Record the predicted PDB of a set of RNAs over the training procedure
  - Record the FAPE loss per example over time
  - Record model checkpoints of best parameters based on validation loss

- Load hold out set?
- Record performance on the hold out set?

In [None]:
from google.colab import drive
drive.mount('/content/drive',  force_remount=True)
import os, sys
folder_name = 'RNAProject' #@param {type:"string"}
os.chdir(os.getcwd()+'/drive/MyDrive/%s'%folder_name)

!pip -q install biopython ml-collections pytorch-lightning


from data.data_transforms import transform, prepare_features
from data.loader import RNA
# from data.build_dataset import open_data, RNAData
from config import model_config, optim_config, TRACK_CODES
from data.generate_fake_example import generate_fake_example
from modules import AlphaFold

import os, json, gzip, glob
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_info

from psutil import virtual_memory

seed_everything(42, workers=True)

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Mounted at /content/drive


Global seed set to 42


Wed Sep  8 23:28:10 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    37W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
### only used in late versions
def RMSE_weighter(experiment, codes, end = -3):
  ''' Turn the training RMSD history of a previous experiment into per-code weights.

  note we can only look at the training performance! else we'd be cheating

  Reads lightning_logs/<experiment>/records.gz and uses its 'train_metrics'
  entry: one list per epoch, each item unpacked as (bi, c, tm, rmsd, gdt, lddt)
  where c is the example code and rmsd may be None.

  Args:
    experiment: name of the sub-folder of lightning_logs/ holding records.gz.
    codes: example codes to produce weights for; the output is aligned with it.
    end: slice start applied to the per-epoch lists, i.e. the weight of a code
      is the mean RMSD over epochs [end:].

  Returns:
    A list with one weight per code. A code's raw value is its mean RMSD over
    the last epochs, falling back to its most recent RMSD from any epoch.
    Raw values are divided by their mean, so the average weight of the codes
    that have a value is 1. Codes without any RMSD get weight 1. If no code has
    a usable RMSD (or the mean is 0) every weight is 1.
  '''
  # load in the logs (the payload is decompressed once by gzip.open and once
  # more by gzip.decompress, matching how the file is read elsewhere in this notebook)
  with gzip.open('lightning_logs/'+experiment+'/records.gz','rb') as f1:
    records = json.loads(gzip.decompress(f1.read()).decode("utf-8"))
  t_rec = records['train_metrics']#['train']

  # drmse[code][epoch] -> list of RMSDs logged for that code in that epoch
  drmse = {c:[[] for _ in range(len(t_rec))] for c in codes}
  for i,e in enumerate(t_rec):
    for bi,c,tm,rmsd,gdt,lddt in e:
      # ignore entries without an RMSD and codes that were logged but not requested
      if rmsd is not None and c in drmse:
        drmse[c][i].append(rmsd)
  
  def get_weight(ep_rmse):
    # preferred: mean over the epochs selected by `end`
    a = [r for e in ep_rmse[end:] for r in e]
    if len(a)!=0:return sum(a)/len(a)
    # fallback: the most recent RMSD recorded in any epoch
    a = [r for e in ep_rmse for r in e]
    if len(a)!=0:return a[-1]
    return None# example was never seen- or RMSE couldn't be computed

  drmse = [get_weight(drmse[c]) for c in codes]
  tot = [r for r in drmse if r is not None]
  av = sum(tot)/len(tot) if len(tot)!=0 else 0
  # nothing to normalise by: fall back to uniform weights instead of dividing by zero
  if av == 0:
    return [1 for _ in drmse]
  drmse = [1 if w is None else w/av for w in drmse]
  return drmse# average weight is 1

In [None]:
def instantiateAF(config, device):
  ''' Build an AlphaFold model and push one generated fake example through it.

  A recycle count is drawn with torch.randint from
  [1, config.data.common.num_recycle) and handed to generate_fake_example.
  The resulting batch is moved to `device` and passed through the new model
  once; the forward outputs are discarded.
  NOTE(review): the forward pass presumably exists to finish building the
  model before training — confirm against the AlphaFold module.

  Returns the AlphaFold instance.
  '''
  recycle_cap = config.data.common.num_recycle
  num_recycle = torch.randint(1, recycle_cap, (1,)).item()

  # a small example is generated automatically, so no real data is needed here
  fake_batch = generate_fake_example(num_recycle)

  model = AlphaFold(config, True, True)
  batch_on_device = {name: value.to(device) for name, value in fake_batch.items()}
  out, loss, info = model(batch_on_device)
  # only the model is wanted; drop the references to the forward results
  del out, loss
  return model

def exp_decay_pmf(n_cat, decay_exponent, weights=None, device='cpu', clip=False):
  ''' Probability mass function with exponentially decaying mass.

  Evaluates e^(-decay_exponent * x) on a grid x, optionally multiplies the
  result element-wise by `weights`, and normalises it to sum to 1.

  The grid is n_cat evenly spaced points on [0, 1], unless `clip` is set, in
  which case a fixed 8-point grid (five zeros, then 0.2, 0.6, 0.8) is used and
  n_cat is ignored.
  '''
  if not clip:
    grid = torch.linspace(0, 1, n_cat, device=device)
  else:
    grid = torch.tensor([0, 0, 0, 0, 0, 0.2, 0.6, 0.8], device=device)

  scaled = decay_exponent * grid
  mass = torch.exp(-scaled)
  if weights is not None:
    # in-place on the fresh tensor, so `weights` itself is never modified
    mass.mul_(weights)
  total = mass.sum()
  return mass / total

class RNAData(Dataset):
  ''' 
  Dataset over pre-featurised RNA examples, indexed in one of three modes:
  `deterministic` (index -> fixed example, random recycle count), `test`
  (index -> fixed example, maximum recycle count) or `randomised` (the index
  is ignored and an example is sampled by resolution bin and per-code weight).

  - Many of the structures are poor resolution > 4.0A.
  - Use these structures at the start of training and less so 
    towards the end of training
  '''
  # names of the resolution bins, in the order used for the bin-level pmf
  res = ['1.5A','2.0A','2.5A','3.0A','3.5A','4.0A','20.0A','above 20']
  def __init__(self, 
      data, config, prepare_features, transform, deterministic=True,
      resolution_bins=None, pmf=None, device='cpu', rate=0.5, 
      num_ep_to_fully_recycle=8, basis=None, basis_shifts=None, test=False, 
      start=0,filter_msa=None, rmse_weight=None, clip_pmf=False, rescale_size=1
    ):
    ''' Featurise `data` once and set up the chosen access mode.

    data: mapping code -> RNA-like object (needs `num_res`; `sequence` and
      `__dict__` are used when filter_msa is given). It is not modified.
    config: config object; only `config.data` is read here.
    prepare_features / transform: callables applied at construction time and
      at item-access time respectively.
    deterministic / test: choose the access mode (see class docstring).
    resolution_bins, pmf, rate, num_ep_to_fully_recycle, start, rmse_weight,
      clip_pmf, rescale_size: only used in the randomised mode.
    filter_msa: if not None, examples whose value in data/msa_evals.json is
      greater than this get their MSA replaced by their own sequence.
    '''
    # examples with num_res <= min_len are dropped
    min_len = 5
    cd = config.data
    kw = {'rna':True, 'basis':basis, 'basis_shifts':basis_shifts}


    # go through data and filter MSA if e-val is too high
    if filter_msa is not None:
      with open('data/msa_evals.json','r') as fh:
        msa_evals = json.loads(fh.read())
      get_rid = {k for k, v in msa_evals.items() if v>filter_msa and k in data}
      print('filtered %d msa to %d msa with e-values better than %s'%(
          len(msa_evals), len(msa_evals)-len(get_rid), str(filter_msa)))
      # work on copies so neither the caller's dict nor its RNA objects are altered
      data = dict(data)
      for k in get_rid:
        d = dict(data[k].__dict__)
        d['msa'] = [d['sequence']]
        data[k] = RNA(**d)
      
    self.data = {k:prepare_features(v, device, **kw) for k,v in tqdm(data.items()) if v.num_res > min_len}
    self.ix2key = {i:k for i,k in enumerate(self.data)}

    if deterministic:
      if test:
        self.gett = self.test
      else:
        self.gett = self.deterministic
      self.effective_length = len(self.data)
    else:
      self.gett = self.randomised

      # if the epoch is greater than this then set num_recycle to its max
      self.num_ep_to_fully_recycle = num_ep_to_fully_recycle

      # resolution_bins is a partition of the training examples (codes)
      self.resolution_bins = [[b for b in resolution_bins[r] if b in self.data] for r in self.res]
      
      # get the pmf for each code in a resolution bin, based on length
      c = cd.training.constant.crop_size
      def get_weight(code):
        ''' length-based sampling weight: rises linearly from min_weight at
        length 0 to 1 at short_len, stays at 1 up to the crop size, then grows
        linearly as length / crop_size. Averaged with the RMSE weight if given '''
        l = self.data[code][0]['seq_length'][0].item()
        # longer sequences than crop size need to be sampled more..
        resampling_score = max(l, c) / c

        # short sequences are easier and don't make best use of parallelism
        min_weight = 0.1
        short_len = 20
        short = min_weight + (1 - min_weight) * min(short_len, l) / short_len
        # if l < 5: short = 0
        w = short * resampling_score
        if rmse_weight is None:return w
        return 0.5 * (c2w[code] + w)# average the RMSE weight with the length weight

      # c2w is looked up inside get_weight, so it has to exist before the map below
      if rmse_weight is not None:
        codes = list(self.data.keys())
        c2w = dict(zip(codes, RMSE_weighter(rmse_weight, codes, end=-3)))

      resolution_seq_weight = [list(map(get_weight, r)) for r in self.resolution_bins]
      self.reso_seq_weights = [torch.tensor(r, device=device)/sum(r) for r in resolution_seq_weight]

      # get pmf for each resolution bin, based on all lengths in that bin
      self.reso_weights = torch.tensor([sum(r) for r in resolution_seq_weight], device=device)

      # one "epoch" is the total sampling weight, optionally rescaled
      self.effective_length = int(self.reso_weights.sum().item() * rescale_size)
                          
      # self.res_counts = [len(r) for r in self.resolution_bins]
      self.pmf = pmf
      self.rate = rate
      self.epoch = start
      self.probs = pmf(len(self.res), start * self.rate, self.reso_weights, device, clip_pmf)
      self.clip_pmf = clip_pmf

    self.config = cd
    self.max_recycle = cd.common.num_recycle
    self.num_ensemble = cd.training.constant.num_ensemble
    self.transform = transform
    self.device = device

  def __len__(self):
    ''' number of items per pass (example count, or rescaled total weight) '''
    return self.effective_length
  
  def randomised(self, idx):
    ''' sample an example; idx is only used to detect the last index of a pass '''
    if self.epoch < self.num_ep_to_fully_recycle:
      num_recycle = torch.randint(1,self.max_recycle, (1,)).item()
    else:
      num_recycle = self.max_recycle

    # NOTE(review): when indices arrive shuffled, the last index can show up
    # anywhere in a pass, so the bin pmf is advanced once per pass but not
    # necessarily at its end — confirm this is acceptable
    end = idx==len(self)-1

    if end:
      self.epoch += 1
      self.probs = self.pmf(len(self.res), self.epoch * self.rate, self.reso_weights, self.device, self.clip_pmf)
    
    # first pick a resolution bin, then a code within that bin
    bin_ = torch.multinomial(self.probs, 1).item()
    i = torch.multinomial(self.reso_seq_weights[bin_], 1).item()
    code = self.resolution_bins[bin_][i]

    batch = self.transform(*self.data[code], self.config, self.num_ensemble, 
                          num_recycle, rna=True)
    batch['code'] = code
    return batch

  def deterministic(self, idx):
    ''' example at position idx, with a randomly drawn recycle count '''
    num_recycle = torch.randint(1,self.max_recycle, (1,)).item()
    code = self.ix2key[idx]
    batch = self.transform(*self.data[code], self.config, self.num_ensemble, 
                          num_recycle, rna=True)
    batch['code'] = code
    return batch
  
  def test(self, idx):
    ''' example at position idx, with the maximum recycle count and test=True '''
    num_recycle = self.max_recycle
    code = self.ix2key[idx]
    batch = self.transform(*self.data[code], self.config, self.num_ensemble, 
                          num_recycle, rna=True, test=True)
    batch['code'] = code
    return batch
  
  def __getitem__(self, idx):
    ''' dispatch to the access mode selected in __init__ '''
    return self.gett(idx)
    
def collate_fn(data):
  ''' Collate function for a batch size of 1: the dataset items are already
  batched, so the single item in `data` is handed back unchanged. '''
  only_item = data[0]
  return only_item

In [None]:
# Use the current CUDA device index when a GPU is present, otherwise the CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

# ---- Colab form parameters (all entered as strings, parsed below) ----
model_name = 'model_5'#@param {type:"string"}
load_params = ''#@param {type:"string"}
config_name = 'original'#@param {type:"string"}
num_epochs = "16"#@param {type:"string"}
LOG_FREQ = "200"#@param {type:"string"}
Jones_basis_suggestion = "True"### used to be a param
JONES_SUGGESTION = Jones_basis_suggestion.lower()=='true'
LOG_FREQ = int(LOG_FREQ)
config = model_config(model_name)
LOAD_PARAMS = load_params.lower()=='true'
checkpoint_version = '14'#@param {type:"string"}
shift_start = '-26'#@param {type:"string"}
evo_blocks = '8'#@param {type:"string"}
fold_layers = ''#@param {type:"string"}
clip_pmf_ = 'True'#@param {type:"string"}
filt_msa = '1e-6'#@param {type:"string"}
filt_msa = None if filt_msa=='' else float(filt_msa)
rmse_weight_experiment = ''#@param {type:"string"}
if rmse_weight_experiment=='': rmse_weight_experiment=None
override_lr = '5e-4'#@param {type:"string"}
epoch_size_scaler = '2.0'#@param {type:"string"}
epoch_size_scaler = float(epoch_size_scaler) if epoch_size_scaler!='' else 1

# Resolve the checkpoint to resume from. START is the epoch parsed from the
# checkpoint file name ("epoch=<n>-step...") plus the shift_start offset.
if checkpoint_version!='':
  path = 'lightning_logs/version_%s/checkpoints'%checkpoint_version
  ckpts = list(glob.glob(path+'/*.ckpt'))
  # exactly one checkpoint is expected; say so instead of failing on the unpack
  if len(ckpts) != 1:
    raise ValueError('expected exactly one .ckpt file in %s, found %d'%(path, len(ckpts)))
  [path] = ckpts
  LOAD_FROM_CHECKPOINT = path
  p = path[path.index('epoch=')+6:path.index('-step')]
  print(p)
  START = int(p) + int(shift_start)
  print('starting at epoch %d'%START)
else:
  LOAD_FROM_CHECKPOINT = None
  START = 0

opt_cfg = optim_config(config_name)
if override_lr!='':
  opt_cfg['optim_groups']['default']['lr'] = float(override_lr)

# An empty form field means "keep the value from the model config".
EVO_BLOCKS = int(evo_blocks) if evo_blocks!='' else None
FOLD_LAYERS = int(fold_layers) if fold_layers!='' else None

# Load every split of the dataset and wrap each record in an RNA object.
# The file handle is closed as soon as the payload has been read.
with gzip.open('data/dataset.gz','rb') as f:
  loaded_rnas = json.loads(gzip.decompress(f.read()).decode("utf-8"))
all_data = {}
for dname, dat in loaded_rnas.items():
  all_data[dname] = {}
  for code,dictt in tqdm(dat.items()):
    all_data[dname][code] = RNA(**dictt)
    
with open('data/resolution_partition.json', 'r') as f:
  resolution_bins = json.loads(f.read())

training_set = all_data['train']
validation_set = all_data['validation']
holdout_set = all_data['hold-out']

## add config.trackcodes
# TRACK_CODES = CHAINS_TO_TRACK#temp#config.trackcodes
# TRACK_CODES_TRAIN_METRICS = CHAINS_FOR_TRAIN_METRICS

31
starting at epoch 5


100%|██████████| 845/845 [00:00<00:00, 215164.33it/s]
100%|██████████| 707/707 [00:00<00:00, 214975.56it/s]
100%|██████████| 2273/2273 [00:00<00:00, 182892.80it/s]


In [None]:
# ks = list(validation_set.keys())
# kk = {k for k in ks if k in resolution_bins['1.5A']}
# ll = [(k,validation_set[k].num_res) for k in kk]
# ll.sort(key=lambda x:x[1])
# print(ll[::-1])

### JUST USED FOR TESTING
# def subset(d, l):
#   ks = set(list(d.keys())[:l])
#   return {k:v for k,v in d.items() if k in ks}

# training_set = subset(training_set, 40)
# validation_set = subset(validation_set, 30)
# holdout_set = subset(holdout_set, 30)
# TRACK_CODES = [c for c in TRACK_CODES if c in validation_set]

In [None]:
# Tell the model config which device is in use and parse the epoch count.
config.model.global_config.device = device
num_epochs = int(num_epochs)

# Positional arguments shared by every RNAData below.
data_args = (config, prepare_features, transform)

# Keyword arguments shared by every RNAData below. Start from my default
# basis (None) and switch to the Jones suggestion when it is enabled.
basis_kw = {'device':device, 'basis':None, 'basis_shifts':None}
if JONES_SUGGESTION:
  basis_kw.update(basis=("C4'",'P','P'), basis_shifts=(0,1,0))

# Optional settings are only passed on when they were set in the config cell.
if filt_msa is not None:
  basis_kw['filter_msa'] = filt_msa
if rmse_weight_experiment is not None:
  basis_kw['rmse_weight'] = rmse_weight_experiment
if clip_pmf_.lower()=='true':
  basis_kw['clip_pmf'] = True

# Training data is sampled at random by resolution bin; validation and
# hold-out data are walked through in a fixed order.
train_gen = RNAData(
    training_set,
    *data_args,
    deterministic=False,
    resolution_bins=resolution_bins,
    pmf=exp_decay_pmf,
    num_ep_to_fully_recycle=num_epochs,
    start=START,
    rescale_size=epoch_size_scaler,
    **basis_kw
)
valid_gen = RNAData(
    validation_set, *data_args, deterministic=True, **basis_kw
)
holdout_gen = RNAData(
    holdout_set, *data_args, deterministic=True, **basis_kw
)

filtered 875 msa to 776 msa with e-values better than 1e-06


100%|██████████| 2273/2273 [01:10<00:00, 32.39it/s]


filtered 875 msa to 844 msa with e-values better than 1e-06


100%|██████████| 707/707 [00:42<00:00, 16.56it/s] 


filtered 875 msa to 862 msa with e-values better than 1e-06


100%|██████████| 845/845 [00:36<00:00, 23.14it/s] 


In [None]:
# Optional overrides of the model config; None means "keep the configured value".
if EVO_BLOCKS is not None:
  config.model.embeddings_and_evoformer.evoformer_num_block = EVO_BLOCKS
if FOLD_LAYERS is not None:
  config.model.heads.structure_module.num_layer = FOLD_LAYERS

# Build the model; instantiateAF also runs one generated fake example through it.
af = instantiateAF(config, device)

# Optionally start from the saved state dict for this model name.
if LOAD_PARAMS:
  af.load_state_dict(torch.load('params/torch_'+model_name))
  
# NOTE(review): eval() is called and then is_training is set to True; these look
# like separate flags on the project model — confirm this combination is intended.
af.eval()
af.is_training = True
af.compute_loss = True

### SET UP THE PARAM GROUPS
af.set_optim_config(opt_cfg, num_epochs)

### SET UP LOGGING AND METRICS
af.track_codes(TRACK_CODES, LOG_FREQ, rec_t_coords_every=4, train_coord_buffer_size=30)

af = af.to(device)

# The datasets return ready-made batches, hence batch_size 1 with collate_fn.
data_conf = {
    'collate_fn':collate_fn,
    'shuffle':True,
    'batch_size':1,
}# I think lightning sorts the rest out... 'num_workers':4, 'pin_memory':True

train_loader = DataLoader(train_gen, **data_conf)

# Same settings for evaluation loaders, but without shuffling.
data_conf['shuffle'] = False

val_loader = DataLoader(valid_gen, **data_conf)
holdout_loader = DataLoader(holdout_gen, **data_conf)

# Keyword arguments for the Lightning Trainer.
train_conf = {
    'auto_lr_find':False,
    'progress_bar_refresh_rate':1, 
    'max_epochs':num_epochs, 
    'gradient_clip_val':0.1, 
    # 'plugins':DeepSpeedPlugin(
    #     stage=3,
    #     cpu_offload=True,  # Enable CPU Offloading
    #     partition_activations=True,
    #     # cpu_checkpointing=True,  # (Optional) offload activations to CPU
    # ),
}
# Request a single GPU only when CUDA is available.
if torch.cuda.is_available(): 
  train_conf['gpus'] = 1
  # train_conf['profiler'] = 'simple'

trainer = Trainer(**train_conf)


# Load the weights stored under "state_dict" in the checkpoint file. This runs
# after the optional LOAD_PARAMS load above, so it is the last state loaded.
if LOAD_FROM_CHECKPOINT is not None:
  print('loading from checkpoint')
  af.load_state_dict(torch.load(LOAD_FROM_CHECKPOINT)["state_dict"])
  print('done')

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(other, self)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


no scheduler found
loading from checkpoint
done


In [8]:
# Run training on train_loader, with val_loader passed as the validation data.
trainer.fit(af, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type               | Params
-----------------------------------------------------------
0 | alphafold_iteration | AlphaFoldIteration | 19.5 M
-----------------------------------------------------------
19.5 M    Trainable params
0         Non-trainable params
19.5 M    Total params
77.834    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 22 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 26 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 10 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 11 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 8 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 0 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 3 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 4 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 1 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 0 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 0 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 0 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 0 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 1 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 0 codes metrics in training step


Validating: 0it [00:00, ?it/s]

1ML5-B failed to record metrics in validation step
4ADX-9 failed to record metrics in validation step
1E8S-C failed to record metrics in validation step
2R1G-E failed to record metrics in validation step
3EP2-C failed to record metrics in validation step
2R1G-X failed to record metrics in validation step
1X18-A failed to record metrics in validation step
1MJ1-R failed to record metrics in validation step
1MVR-E failed to record metrics in validation step
4V5Z-BB failed to record metrics in validation step
1TRJ-C failed to record metrics in validation step
1QZC-C failed to record metrics in validation step
4V5Z-BM failed to record metrics in validation step
Failed to record 0 codes metrics in training step


In [None]:
# Project-defined clean-up on the model after training; in the recorded run it
# prints "Tidying logs... Done.". NOTE(review): what it writes is not visible here.
af.tidy_data()

Tidying logs... Done.


In [None]:
# Evaluate on the hold-out loader using the "best" checkpoint tracked by the trainer.
# `dataloaders=` replaces `test_dataloaders=`, which the recorded run reports as
# deprecated in v1.4 and scheduled for removal in v1.6.
trainer.test(dataloaders=holdout_loader, ckpt_path="best")

# trainer.test(ckpt_path="/path/to/my_checkpoint.ckpt")
# trainer.test(model)

  "`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]



3EQ4-Y failed to record metrics in test step
2AGN-A failed to record metrics in test step
2R1G-D failed to record metrics in test step
486D-G failed to record metrics in test step
4V5Z-AE failed to record metrics in test step
4V5Z-AB failed to record metrics in test step
4V5Z-BL failed to record metrics in test step
1QZC-B failed to record metrics in test step
1X18-C failed to record metrics in test step
4V5Z-BW failed to record metrics in test step
4V5Z-BD failed to record metrics in test step
1ZC8-B failed to record metrics in test step
4V5Z-BU failed to record metrics in test step
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 3.764826536178589}
--------------------------------------------------------------------------------


[{'test_loss': 3.764826536178589}]