In [1]:
from pathlib import Path
import numpy as np
import pandas as pd
import os
import sys
import torch

sys.path.append('/hpc/home/ma618/cxrgen/')
sys.path.append('/hpc/home/ma618/cxrgen/src/')

In [2]:
from src.data import dataloaders

In [3]:
root = Path('/hpc/group/kamaleswaranlab/EmoryDataset/Images/chest_xrays')
dataset = root / 'longitudinal_data_corrected'
embedding_path = dataset / 'image_embeddings'
ehr_path = dataset / 'ehr_matrices'

In [4]:
supertable_path = root / 'matched_supertables_with_images'
sups = list(supertable_path.glob('*.pickle'))
len(sups)

244895

In [5]:
df = pd.read_pickle(sups[0])
df

Unnamed: 0,temperature,daily_weight_kg,height_cm,sbp_line,dbp_line,map_line,sbp_cuff,dbp_cuff,map_cuff,pulse,...,Albumin 5%_dose,infection,sepsis,CXR_ACC_NUM_1,vent_mode,vent_rate_set,vent_tidal_rate_set,vent_tidal_rate_exhaled,peep,vent_fio2
2016-05-30 07:29:53,36.6,60.0,,,,,72.000000,36.000000,,75.000000,...,,0,0,,AC/CMV Volume,12,450,402,5,0.50
2016-05-30 08:29:53,36.6,54.1,188.0,,,,134.000000,62.000000,,53.000000,...,,1,1,,,,,,,
2016-05-30 09:29:53,36.0,,,,,,107.333333,55.000000,74.666667,61.666667,...,,1,1,,AC/CMV Volume,15,450,453,5,0.4
2016-05-30 10:29:53,,,,,,,109.333333,57.333333,79.666667,59.000000,...,,1,1,,,,,,,
2016-05-30 11:29:53,,,,,,,114.000000,57.666667,79.666667,65.000000,...,,1,1,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2016-06-14 09:29:53,37.7,,,,,,122.000000,58.000000,83.000000,94.000000,...,,1,1,00004DX160001032,,,,,,
2016-06-14 10:29:53,,,,,,,,,,88.000000,...,,1,1,00004DX160001032,,,,,,
2016-06-14 11:29:53,37.8,,,,,,131.000000,66.000000,92.000000,92.000000,...,,1,1,00004DX160001032,AC/CMV Volume,20,,,5,0.3
2016-06-14 12:29:53,,,,,,,,,,91.500000,...,,1,1,00004DX160001032,,,,,,


In [6]:
encounter_paths = list(ehr_path.glob("*.npy"))
prev_cxr_paths = list(embedding_path.glob("*_ffill_embeddings.npy"))
target_paths = list(embedding_path.glob("*_interpolated_embeddings.npy"))
len(encounter_paths), len(prev_cxr_paths), len(target_paths)

(17690, 17690, 17690)

In [7]:
encounter_paths.sort()
prev_cxr_paths.sort()
target_paths.sort()

In [8]:
batch_size = 8
model_type = 'transformer'
max_seq_length = 100
num_workers  = 8
shuffle = True


In [9]:
dataloader = dataloaders.create_encounter_dataloaders(encounter_paths=encounter_paths, 
                                        prev_cxr_paths=prev_cxr_paths,
                                        target_paths=target_paths,
                                        batch_size=batch_size,
                                        model_type=model_type,
                                        max_seq_length=max_seq_length,
                                        num_workers=num_workers,
                                        shuffle=shuffle)

In [10]:
batch = next(iter(dataloader))

In [11]:
batch['ehr'].shape, batch['prev_cxr'].shape, batch['target'].shape, batch['attention_mask'].shape

(torch.Size([8, 100, 81]),
 torch.Size([8, 100, 512]),
 torch.Size([8, 100, 512]),
 torch.Size([8, 100]))

In [12]:
import importlib
import src
import src.models.transformer as transformer
import src.models.transformernn as transformernn

from tqdm import tqdm
from src.training.trainer import TrainerConfig

In [13]:
import src.configs.config_def as cf
cf.set_paths()
configs = cf.run_configs()

In [14]:
importlib.reload(src.models.transformernn)

<module 'src.models.transformernn' from '/hpc/home/ma618/cxrgen/src/models/transformernn.py'>

In [24]:
config = {
        'ehr_dim': 81,
        'cxr_dim': 512,
        'd_model': 512,
        'num_encoder_layers': 6,
        'num_decoder_layers': 6,
        'num_heads': 8,
        'mlp_ratio': 4.0,
        'dropout': 0.1,
        'max_seq_length': 100
    }
    
# Create model
model = transformernn.create_transformer_model(config)

In [25]:
inf_mask = torch.isinf(batch['ehr'])
inf_indices = torch.where(inf_mask)
num_infs = inf_mask.sum().item()

if inf_mask.any():
    # For a multi-dimensional tensor
    for idx in zip(*torch.where(inf_mask)):
        print(f"Inf found at index: {idx}")

In [26]:
predictions = model(batch['ehr'],  batch['prev_cxr'], None, batch['attention_mask'], batch['attention_mask'], debug = True, batch_idx_debug = 4)

EHR input
tensor([[0.4167, 0.6767, 0.2889,  ..., 0.0000, 0.0000, 0.0000],
        [0.5000, 0.6767, 0.2667,  ..., 0.0000, 0.0000, 0.0000],
        [0.8000, 0.6767, 0.4667,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
ehr - min: -6.5, max: 531.0, has_nan: False
EHR embedding
tensor([[ 1.1481e-01,  9.9440e-02, -2.7472e-02,  ...,  4.8887e-01,
          5.6493e-02, -2.2732e-01],
        [ 1.2922e-01,  3.0371e-02,  9.9204e-05,  ...,  4.7830e-01,
          7.8608e-02, -2.5675e-01],
        [ 1.1892e-01,  1.2091e-01,  5.4149e-02,  ...,  4.7943e-01,
          6.4979e-02, -3.3386e-01],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
      



In [27]:
predictions[4,:,:]

tensor([[ 0.0310,  0.4832, -0.2951,  ...,  1.2988, -0.2955, -0.4847],
        [ 0.2085,  0.5930, -0.3348,  ...,  0.7263, -0.0880, -0.3932],
        [ 0.1833,  0.3796, -0.3958,  ...,  0.9673, -0.2616, -0.2159],
        ...,
        [ 0.1569,  0.3412,  0.1895,  ...,  1.0137, -0.1192, -0.2396],
        [-0.0970,  0.6146,  0.1632,  ...,  0.7509, -0.0865, -0.1076],
        [ 0.0771,  0.6744, -0.0240,  ...,  0.6644, -0.0337, -0.1764]],
       grad_fn=<SliceBackward0>)

In [28]:
trainer_config = TrainerConfig(
        max_epochs=configs['num_epochs'],
        batch_size=configs['batch_size'],
        learning_rate=configs['lr'],
        weight_decay=0.01,
        lr_scheduler="cosine",
        warmup_ratio=0.1,
        checkpoint_dir= './',
        save_every=5,
        patience=15,
        grad_norm_clip=1.0,
        log_every=10,
        eval_every=1,
        mixed_precision=True,
        teacher_forcing_ratio=0.5,
        teacher_forcing_decay=0.98,  # Gradually reduce teacher forcing
    )

In [29]:
device = 'cpu'
        
# Create checkpoint directory

# Initialize teacher forcing ratio
teacher_forcing_ratio = trainer_config.teacher_forcing_ratio

# Setup optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=trainer_config.learning_rate,
    weight_decay=trainer_config.weight_decay,
)

# Setup learning rate scheduler
total_steps = len(dataloader) * trainer_config.max_epochs
warmup_steps = int(total_steps * trainer_config.warmup_ratio)

if trainer_config.lr_scheduler == "cosine":
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=total_steps - warmup_steps
    )
elif trainer_config.lr_scheduler == "linear":
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, 
        start_factor=1.0,
        end_factor=0.1,
        total_iters=total_steps - warmup_steps
    )
elif trainer_config.lr_scheduler == "constant":
    scheduler = torch.optim.lr_scheduler.ConstantLR(
        optimizer, factor=1.0, total_iters=total_steps
    )
else:
    raise ValueError(f"Unknown scheduler: {trainer_config.lr_scheduler}")

# Warmup scheduler
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=warmup_steps
)

# Setup loss function
criterion = torch.nn.MSELoss()

# Setup mixed precision training
scaler = torch.cuda.amp.GradScaler(enabled=trainer_config.mixed_precision)

# Setup model
model.to(device)

# Tracking variables
global_step = 0
best_val_loss = float('inf')
epochs_without_improvement = 0

  scaler = torch.cuda.amp.GradScaler(enabled=trainer_config.mixed_precision)


In [32]:
model.train()
total_loss = 0.0
epoch = 1

with tqdm(dataloader, unit="batch", desc=f"Epoch {epoch}") as progress_bar:
        for batch_idx, batch in enumerate(progress_bar):
            # Move batch to device
            batch_encounters = batch['encounter_name']
            del batch['encounter_name']
            batch = {k: v.to(device) for k, v in batch.items()}
            
            batch['ehr'][torch.isinf(batch['ehr'])] = 0
            
            # Decide on teacher forcing
            use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio
            target_input = batch["target"][:, :-1] if use_teacher_forcing else None

            # Forward pass with mixed precision
            with torch.cuda.amp.autocast(enabled=trainer_config.mixed_precision):
                outputs = model(
                    ehr=batch["ehr"],
                    prev_cxr=batch["prev_cxr"],
                    target_input=None,
                    encoder_attention_mask=batch.get("attention_mask"),
                    decoder_attention_mask=batch.get("attention_mask"),
                    causal_mask=True
                )
                loss = criterion(outputs, batch["target"])
            if torch.isnan(loss):
                break

            # Backward pass with gradient scaling
            optimizer.zero_grad()
            scaler.scale(loss).backward()

            # Gradient clipping
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), trainer_config.grad_norm_clip)

            # Update weights
            scaler.step(optimizer)
            scaler.update()

            # Update learning rate
            if global_step < warmup_steps:
                warmup_scheduler.step()
            else:
                scheduler.step()

            # Update metrics
            total_loss += loss.item()
            global_step += 1

            # Update progress bar
            progress_bar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

            # Log metrics
            if batch_idx % trainer_config.log_every == 0:
                print({
                    "train/batch_loss": loss.item(),
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                    "train/global_step": global_step,
                })

avg_loss = total_loss / len(dataloader)

  with torch.cuda.amp.autocast(enabled=trainer_config.mixed_precision):
Epoch 1:   0%|          | 1/2212 [00:04<2:30:02,  4.07s/batch, loss=5.97, lr=1.11e-5]

{'train/batch_loss': 5.967195510864258, 'train/learning_rate': 1.1106690777576851e-05, 'train/global_step': 136}


Epoch 1:   0%|          | 11/2212 [00:09<15:26,  2.38batch/s, loss=5.89, lr=1.12e-5] 

{'train/batch_loss': 5.893756866455078, 'train/learning_rate': 1.1188065099457504e-05, 'train/global_step': 146}


Epoch 1:   1%|          | 21/2212 [00:12<13:50,  2.64batch/s, loss=6.52, lr=1.13e-5]

{'train/batch_loss': 6.5240983963012695, 'train/learning_rate': 1.1269439421338157e-05, 'train/global_step': 156}


Epoch 1:   1%|▏         | 31/2212 [00:16<11:05,  3.28batch/s, loss=5.68, lr=1.14e-5]

{'train/batch_loss': 5.675389289855957, 'train/learning_rate': 1.1350813743218806e-05, 'train/global_step': 166}


Epoch 1:   2%|▏         | 41/2212 [00:19<10:17,  3.51batch/s, loss=5.71, lr=1.14e-5]

{'train/batch_loss': 5.712188720703125, 'train/learning_rate': 1.1432188065099458e-05, 'train/global_step': 176}


Epoch 1:   2%|▏         | 51/2212 [00:22<12:50,  2.80batch/s, loss=5.2, lr=1.15e-5] 

{'train/batch_loss': 5.204791069030762, 'train/learning_rate': 1.1513562386980113e-05, 'train/global_step': 186}


Epoch 1:   3%|▎         | 61/2212 [00:26<10:58,  3.27batch/s, loss=4.4, lr=1.16e-5] 

{'train/batch_loss': 4.395272731781006, 'train/learning_rate': 1.1594936708860759e-05, 'train/global_step': 196}


Epoch 1:   3%|▎         | 71/2212 [00:29<10:32,  3.38batch/s, loss=4.82, lr=1.17e-5]

{'train/batch_loss': 4.817283630371094, 'train/learning_rate': 1.167631103074141e-05, 'train/global_step': 206}


Epoch 1:   4%|▎         | 81/2212 [00:32<10:12,  3.48batch/s, loss=6.53, lr=1.18e-5]

{'train/batch_loss': 6.530384063720703, 'train/learning_rate': 1.1757685352622057e-05, 'train/global_step': 216}


Epoch 1:   4%|▍         | 91/2212 [00:35<09:23,  3.76batch/s, loss=6.03, lr=1.18e-5]

{'train/batch_loss': 6.025198936462402, 'train/learning_rate': 1.1839059674502706e-05, 'train/global_step': 226}


Epoch 1:   5%|▍         | 101/2212 [00:38<10:05,  3.48batch/s, loss=5.49, lr=1.19e-5]

{'train/batch_loss': 5.487472057342529, 'train/learning_rate': 1.1920433996383356e-05, 'train/global_step': 236}


Epoch 1:   5%|▌         | 111/2212 [00:42<11:46,  2.97batch/s, loss=5.67, lr=1.2e-5] 

{'train/batch_loss': 5.666959285736084, 'train/learning_rate': 1.2001808318264005e-05, 'train/global_step': 246}


Epoch 1:   5%|▌         | 121/2212 [00:44<09:40,  3.60batch/s, loss=5.35, lr=1.21e-5]

{'train/batch_loss': 5.347719669342041, 'train/learning_rate': 1.2083182640144654e-05, 'train/global_step': 256}


Epoch 1:   6%|▌         | 131/2212 [00:48<10:05,  3.44batch/s, loss=5.65, lr=1.22e-5]

{'train/batch_loss': 5.652981758117676, 'train/learning_rate': 1.2164556962025307e-05, 'train/global_step': 266}


Epoch 1:   6%|▋         | 141/2212 [00:51<10:23,  3.32batch/s, loss=5.18, lr=1.22e-5]

{'train/batch_loss': 5.1848955154418945, 'train/learning_rate': 1.2245931283905956e-05, 'train/global_step': 276}


Epoch 1:   7%|▋         | 151/2212 [00:54<11:27,  3.00batch/s, loss=4.76, lr=1.23e-5]

{'train/batch_loss': 4.7589111328125, 'train/learning_rate': 1.2327305605786602e-05, 'train/global_step': 286}


Epoch 1:   7%|▋         | 161/2212 [00:57<10:59,  3.11batch/s, loss=4.41, lr=1.24e-5]

{'train/batch_loss': 4.4064507484436035, 'train/learning_rate': 1.2408679927667254e-05, 'train/global_step': 296}


Epoch 1:   8%|▊         | 171/2212 [01:00<10:40,  3.18batch/s, loss=3.34, lr=1.25e-5]

{'train/batch_loss': 3.342633008956909, 'train/learning_rate': 1.2490054249547902e-05, 'train/global_step': 306}


Epoch 1:   8%|▊         | 181/2212 [01:04<11:31,  2.94batch/s, loss=3.25, lr=1.26e-5]

{'train/batch_loss': 3.246659278869629, 'train/learning_rate': 1.2571428571428551e-05, 'train/global_step': 316}


Epoch 1:   9%|▊         | 191/2212 [01:07<09:23,  3.59batch/s, loss=5.16, lr=1.27e-5]

{'train/batch_loss': 5.155475616455078, 'train/learning_rate': 1.2652802893309199e-05, 'train/global_step': 326}


Epoch 1:   9%|▉         | 201/2212 [01:09<09:11,  3.65batch/s, loss=4.2, lr=1.27e-5] 

{'train/batch_loss': 4.199784755706787, 'train/learning_rate': 1.2734177215189846e-05, 'train/global_step': 336}


Epoch 1:  10%|▉         | 211/2212 [01:12<09:07,  3.65batch/s, loss=4.39, lr=1.28e-5]

{'train/batch_loss': 4.388845443725586, 'train/learning_rate': 1.2815551537070499e-05, 'train/global_step': 346}


Epoch 1:  10%|▉         | 221/2212 [01:15<09:05,  3.65batch/s, loss=4.24, lr=1.29e-5]

{'train/batch_loss': 4.242391586303711, 'train/learning_rate': 1.2896925858951147e-05, 'train/global_step': 356}


Epoch 1:  10%|█         | 231/2212 [01:18<09:04,  3.64batch/s, loss=3.57, lr=1.3e-5] 

{'train/batch_loss': 3.57267427444458, 'train/learning_rate': 1.2978300180831804e-05, 'train/global_step': 366}


Epoch 1:  11%|█         | 241/2212 [01:21<11:30,  2.86batch/s, loss=3.18, lr=1.31e-5]

{'train/batch_loss': 3.177051067352295, 'train/learning_rate': 1.3059674502712459e-05, 'train/global_step': 376}


Epoch 1:  11%|█▏        | 251/2212 [01:24<10:19,  3.16batch/s, loss=4.46, lr=1.31e-5]

{'train/batch_loss': 4.460968017578125, 'train/learning_rate': 1.3141048824593106e-05, 'train/global_step': 386}


Epoch 1:  12%|█▏        | 261/2212 [01:27<09:31,  3.42batch/s, loss=4.7, lr=1.32e-5] 

{'train/batch_loss': 4.7039947509765625, 'train/learning_rate': 1.3222423146473755e-05, 'train/global_step': 396}


Epoch 1:  12%|█▏        | 271/2212 [01:30<08:59,  3.60batch/s, loss=3.39, lr=1.33e-5]

{'train/batch_loss': 3.3918535709381104, 'train/learning_rate': 1.3303797468354401e-05, 'train/global_step': 406}


Epoch 1:  13%|█▎        | 281/2212 [01:33<10:09,  3.17batch/s, loss=3.88, lr=1.34e-5]

{'train/batch_loss': 3.876826286315918, 'train/learning_rate': 1.3385171790235052e-05, 'train/global_step': 416}


Epoch 1:  13%|█▎        | 291/2212 [01:36<09:50,  3.25batch/s, loss=3.19, lr=1.35e-5]

{'train/batch_loss': 3.187915563583374, 'train/learning_rate': 1.3466546112115705e-05, 'train/global_step': 426}


Epoch 1:  14%|█▎        | 301/2212 [01:39<12:09,  2.62batch/s, loss=3.19, lr=1.35e-5]

{'train/batch_loss': 3.1895246505737305, 'train/learning_rate': 1.3547920433996354e-05, 'train/global_step': 436}


Epoch 1:  14%|█▍        | 311/2212 [01:42<09:29,  3.34batch/s, loss=3.99, lr=1.36e-5]

{'train/batch_loss': 3.990095853805542, 'train/learning_rate': 1.3629294755877003e-05, 'train/global_step': 446}


Epoch 1:  15%|█▍        | 321/2212 [01:45<09:50,  3.20batch/s, loss=3.89, lr=1.37e-5]

{'train/batch_loss': 3.891892433166504, 'train/learning_rate': 1.3710669077757647e-05, 'train/global_step': 456}


Epoch 1:  15%|█▍        | 331/2212 [01:48<08:51,  3.54batch/s, loss=3.94, lr=1.38e-5]

{'train/batch_loss': 3.9417285919189453, 'train/learning_rate': 1.3792043399638298e-05, 'train/global_step': 466}


Epoch 1:  15%|█▌        | 341/2212 [01:52<08:56,  3.48batch/s, loss=3.85, lr=1.39e-5]

{'train/batch_loss': 3.846552848815918, 'train/learning_rate': 1.3873417721518946e-05, 'train/global_step': 476}


Epoch 1:  16%|█▌        | 351/2212 [01:55<08:48,  3.52batch/s, loss=3.89, lr=1.4e-5] 

{'train/batch_loss': 3.8945860862731934, 'train/learning_rate': 1.3954792043399597e-05, 'train/global_step': 486}


Epoch 1:  16%|█▋        | 361/2212 [01:57<08:19,  3.71batch/s, loss=3.76, lr=1.4e-5]

{'train/batch_loss': 3.760141611099243, 'train/learning_rate': 1.403616636528025e-05, 'train/global_step': 496}


Epoch 1:  17%|█▋        | 371/2212 [02:01<10:39,  2.88batch/s, loss=3.15, lr=1.41e-5]

{'train/batch_loss': 3.151137590408325, 'train/learning_rate': 1.41175406871609e-05, 'train/global_step': 506}


Epoch 1:  17%|█▋        | 381/2212 [02:04<09:15,  3.29batch/s, loss=2.92, lr=1.42e-5]

{'train/batch_loss': 2.9232776165008545, 'train/learning_rate': 1.4198915009041551e-05, 'train/global_step': 516}


Epoch 1:  18%|█▊        | 391/2212 [02:07<10:06,  3.00batch/s, loss=3.91, lr=1.43e-5]

{'train/batch_loss': 3.913337469100952, 'train/learning_rate': 1.4280289330922206e-05, 'train/global_step': 526}


Epoch 1:  18%|█▊        | 401/2212 [02:10<08:33,  3.53batch/s, loss=3.76, lr=1.44e-5]

{'train/batch_loss': 3.7635393142700195, 'train/learning_rate': 1.4361663652802862e-05, 'train/global_step': 536}


Epoch 1:  19%|█▊        | 411/2212 [02:13<08:45,  3.43batch/s, loss=3.64, lr=1.44e-5]

{'train/batch_loss': 3.6407220363616943, 'train/learning_rate': 1.4443037974683511e-05, 'train/global_step': 546}


Epoch 1:  19%|█▉        | 421/2212 [02:16<08:32,  3.49batch/s, loss=2.9, lr=1.45e-5] 

{'train/batch_loss': 2.8981869220733643, 'train/learning_rate': 1.4524412296564164e-05, 'train/global_step': 556}


Epoch 1:  19%|█▉        | 431/2212 [02:19<09:04,  3.27batch/s, loss=2.69, lr=1.46e-5]

{'train/batch_loss': 2.689687490463257, 'train/learning_rate': 1.4605786618444811e-05, 'train/global_step': 566}


Epoch 1:  20%|█▉        | 441/2212 [02:22<09:16,  3.18batch/s, loss=3.2, lr=1.47e-5] 

{'train/batch_loss': 3.2041375637054443, 'train/learning_rate': 1.4687160940325457e-05, 'train/global_step': 576}


Epoch 1:  20%|██        | 451/2212 [02:25<07:58,  3.68batch/s, loss=2.8, lr=1.48e-5] 

{'train/batch_loss': 2.8044557571411133, 'train/learning_rate': 1.4768535262206103e-05, 'train/global_step': 586}


Epoch 1:  21%|██        | 461/2212 [02:28<09:22,  3.11batch/s, loss=3.05, lr=1.48e-5]

{'train/batch_loss': 3.0477161407470703, 'train/learning_rate': 1.4849909584086754e-05, 'train/global_step': 596}


Epoch 1:  21%|██▏       | 471/2212 [02:31<10:21,  2.80batch/s, loss=3.17, lr=1.49e-5]

{'train/batch_loss': 3.1690409183502197, 'train/learning_rate': 1.4931283905967412e-05, 'train/global_step': 606}


Epoch 1:  22%|██▏       | 481/2212 [02:35<12:04,  2.39batch/s, loss=2.93, lr=1.5e-5] 

{'train/batch_loss': 2.9284048080444336, 'train/learning_rate': 1.5012658227848066e-05, 'train/global_step': 616}


Epoch 1:  22%|██▏       | 491/2212 [02:38<08:40,  3.31batch/s, loss=2.66, lr=1.51e-5]

{'train/batch_loss': 2.663022518157959, 'train/learning_rate': 1.5094032549728712e-05, 'train/global_step': 626}


Epoch 1:  23%|██▎       | 501/2212 [02:41<08:11,  3.48batch/s, loss=2.91, lr=1.52e-5]

{'train/batch_loss': 2.912884473800659, 'train/learning_rate': 1.517540687160937e-05, 'train/global_step': 636}


Epoch 1:  23%|██▎       | 511/2212 [02:44<08:47,  3.23batch/s, loss=1.89, lr=1.53e-5]

{'train/batch_loss': 1.894087553024292, 'train/learning_rate': 1.5256781193490024e-05, 'train/global_step': 646}


Epoch 1:  24%|██▎       | 521/2212 [02:47<07:59,  3.52batch/s, loss=2.88, lr=1.53e-5]

{'train/batch_loss': 2.880065679550171, 'train/learning_rate': 1.5338155515370676e-05, 'train/global_step': 656}


Epoch 1:  24%|██▍       | 531/2212 [02:50<09:59,  2.80batch/s, loss=2.18, lr=1.54e-5]

{'train/batch_loss': 2.177847146987915, 'train/learning_rate': 1.5419529837251322e-05, 'train/global_step': 666}


Epoch 1:  24%|██▍       | 541/2212 [02:53<07:59,  3.48batch/s, loss=2.55, lr=1.55e-5]

{'train/batch_loss': 2.549988031387329, 'train/learning_rate': 1.5500904159131968e-05, 'train/global_step': 676}


Epoch 1:  25%|██▍       | 551/2212 [02:57<12:00,  2.31batch/s, loss=2.8, lr=1.56e-5] 

{'train/batch_loss': 2.797680139541626, 'train/learning_rate': 1.5582278481012614e-05, 'train/global_step': 686}


Epoch 1:  25%|██▌       | 561/2212 [03:00<09:08,  3.01batch/s, loss=2.24, lr=1.57e-5]

{'train/batch_loss': 2.242486000061035, 'train/learning_rate': 1.566365280289326e-05, 'train/global_step': 696}


Epoch 1:  26%|██▌       | 571/2212 [03:04<08:17,  3.30batch/s, loss=2.3, lr=1.57e-5] 

{'train/batch_loss': 2.3042643070220947, 'train/learning_rate': 1.574502712477391e-05, 'train/global_step': 706}


Epoch 1:  26%|██▋       | 581/2212 [03:07<07:49,  3.47batch/s, loss=2.87, lr=1.58e-5]

{'train/batch_loss': 2.8668417930603027, 'train/learning_rate': 1.582640144665457e-05, 'train/global_step': 716}


Epoch 1:  27%|██▋       | 591/2212 [03:10<08:41,  3.11batch/s, loss=3.12, lr=1.59e-5]

{'train/batch_loss': 3.123838186264038, 'train/learning_rate': 1.590777576853522e-05, 'train/global_step': 726}


Epoch 1:  27%|██▋       | 601/2212 [03:13<07:49,  3.43batch/s, loss=2.28, lr=1.6e-5] 

{'train/batch_loss': 2.2801268100738525, 'train/learning_rate': 1.598915009041587e-05, 'train/global_step': 736}


Epoch 1:  28%|██▊       | 611/2212 [03:17<09:38,  2.77batch/s, loss=2.1, lr=1.61e-5] 

{'train/batch_loss': 2.0968313217163086, 'train/learning_rate': 1.6070524412296523e-05, 'train/global_step': 746}


Epoch 1:  28%|██▊       | 621/2212 [03:20<10:00,  2.65batch/s, loss=1.64, lr=1.62e-5]

{'train/batch_loss': 1.635006308555603, 'train/learning_rate': 1.615189873417717e-05, 'train/global_step': 756}


Epoch 1:  29%|██▊       | 631/2212 [03:23<08:51,  2.97batch/s, loss=2.95, lr=1.62e-5]

{'train/batch_loss': 2.9500157833099365, 'train/learning_rate': 1.6233273056057818e-05, 'train/global_step': 766}


Epoch 1:  29%|██▉       | 641/2212 [03:27<07:53,  3.32batch/s, loss=2.57, lr=1.63e-5]

{'train/batch_loss': 2.5667288303375244, 'train/learning_rate': 1.6314647377938464e-05, 'train/global_step': 776}


Epoch 1:  29%|██▉       | 651/2212 [03:30<07:20,  3.54batch/s, loss=2.02, lr=1.64e-5]

{'train/batch_loss': 2.016770601272583, 'train/learning_rate': 1.639602169981912e-05, 'train/global_step': 786}


Epoch 1:  30%|██▉       | 661/2212 [03:32<07:19,  3.53batch/s, loss=2.2, lr=1.65e-5] 

{'train/batch_loss': 2.2004778385162354, 'train/learning_rate': 1.6477396021699776e-05, 'train/global_step': 796}


Epoch 1:  30%|███       | 671/2212 [03:36<09:25,  2.73batch/s, loss=2.19, lr=1.66e-5]

{'train/batch_loss': 2.1855506896972656, 'train/learning_rate': 1.6558770343580425e-05, 'train/global_step': 806}


Epoch 1:  31%|███       | 681/2212 [03:39<07:25,  3.43batch/s, loss=1.96, lr=1.66e-5]

{'train/batch_loss': 1.9568480253219604, 'train/learning_rate': 1.664014466546108e-05, 'train/global_step': 816}


Epoch 1:  31%|███       | 691/2212 [03:43<09:56,  2.55batch/s, loss=1.89, lr=1.67e-5]

{'train/batch_loss': 1.8905129432678223, 'train/learning_rate': 1.672151898734173e-05, 'train/global_step': 826}


Epoch 1:  32%|███▏      | 701/2212 [03:46<08:45,  2.88batch/s, loss=2.11, lr=1.68e-5]

{'train/batch_loss': 2.1069343090057373, 'train/learning_rate': 1.6802893309222387e-05, 'train/global_step': 836}


Epoch 1:  32%|███▏      | 711/2212 [03:50<09:29,  2.63batch/s, loss=2, lr=1.69e-5]   

{'train/batch_loss': 2.003554344177246, 'train/learning_rate': 1.6884267631103036e-05, 'train/global_step': 846}


Epoch 1:  33%|███▎      | 721/2212 [03:53<07:45,  3.20batch/s, loss=2.61, lr=1.7e-5] 

{'train/batch_loss': 2.6100096702575684, 'train/learning_rate': 1.696564195298369e-05, 'train/global_step': 856}


Epoch 1:  33%|███▎      | 731/2212 [03:56<07:33,  3.27batch/s, loss=2.75, lr=1.7e-5]

{'train/batch_loss': 2.7528066635131836, 'train/learning_rate': 1.7047016274864338e-05, 'train/global_step': 866}


Epoch 1:  33%|███▎      | 741/2212 [03:59<08:09,  3.01batch/s, loss=1.95, lr=1.71e-5]

{'train/batch_loss': 1.948580026626587, 'train/learning_rate': 1.712839059674498e-05, 'train/global_step': 876}


Epoch 1:  34%|███▍      | 751/2212 [04:03<07:52,  3.09batch/s, loss=1.62, lr=1.72e-5]

{'train/batch_loss': 1.6207503080368042, 'train/learning_rate': 1.7209764918625636e-05, 'train/global_step': 886}


Epoch 1:  34%|███▍      | 761/2212 [04:06<07:08,  3.39batch/s, loss=2.41, lr=1.73e-5]

{'train/batch_loss': 2.4105887413024902, 'train/learning_rate': 1.7291139240506296e-05, 'train/global_step': 896}


Epoch 1:  35%|███▍      | 771/2212 [04:09<08:51,  2.71batch/s, loss=1.94, lr=1.74e-5]

{'train/batch_loss': 1.9409290552139282, 'train/learning_rate': 1.7372513562386945e-05, 'train/global_step': 906}


Epoch 1:  35%|███▌      | 781/2212 [04:12<07:01,  3.40batch/s, loss=2.04, lr=1.75e-5]

{'train/batch_loss': 2.0360231399536133, 'train/learning_rate': 1.7453887884267598e-05, 'train/global_step': 916}


Epoch 1:  36%|███▌      | 791/2212 [04:15<06:35,  3.59batch/s, loss=1.86, lr=1.75e-5]

{'train/batch_loss': 1.8569406270980835, 'train/learning_rate': 1.753526220614825e-05, 'train/global_step': 926}


Epoch 1:  36%|███▌      | 801/2212 [04:18<06:56,  3.38batch/s, loss=1.85, lr=1.76e-5]

{'train/batch_loss': 1.849409580230713, 'train/learning_rate': 1.7616636528028893e-05, 'train/global_step': 936}


Epoch 1:  37%|███▋      | 811/2212 [04:21<07:24,  3.15batch/s, loss=2.43, lr=1.77e-5]

{'train/batch_loss': 2.426612377166748, 'train/learning_rate': 1.769801084990955e-05, 'train/global_step': 946}


Epoch 1:  37%|███▋      | 821/2212 [04:24<08:42,  2.66batch/s, loss=1.11, lr=1.78e-5]

{'train/batch_loss': 1.1074830293655396, 'train/learning_rate': 1.77793851717902e-05, 'train/global_step': 956}


Epoch 1:  38%|███▊      | 831/2212 [04:28<06:47,  3.39batch/s, loss=1.65, lr=1.79e-5]

{'train/batch_loss': 1.645728588104248, 'train/learning_rate': 1.786075949367084e-05, 'train/global_step': 966}


Epoch 1:  38%|███▊      | 841/2212 [04:31<08:59,  2.54batch/s, loss=1.96, lr=1.79e-5]

{'train/batch_loss': 1.9578452110290527, 'train/learning_rate': 1.794213381555149e-05, 'train/global_step': 976}


Epoch 1:  38%|███▊      | 851/2212 [04:34<07:58,  2.84batch/s, loss=2.21, lr=1.8e-5] 

{'train/batch_loss': 2.2148468494415283, 'train/learning_rate': 1.802350813743214e-05, 'train/global_step': 986}


Epoch 1:  39%|███▉      | 861/2212 [04:38<06:44,  3.34batch/s, loss=1.39, lr=1.81e-5]

{'train/batch_loss': 1.3931869268417358, 'train/learning_rate': 1.8104882459312788e-05, 'train/global_step': 996}


Epoch 1:  39%|███▉      | 871/2212 [04:41<07:45,  2.88batch/s, loss=2.64, lr=1.82e-5]

{'train/batch_loss': 2.641312837600708, 'train/learning_rate': 1.8186256781193437e-05, 'train/global_step': 1006}


Epoch 1:  40%|███▉      | 881/2212 [04:44<08:23,  2.64batch/s, loss=1.68, lr=1.83e-5]

{'train/batch_loss': 1.678617238998413, 'train/learning_rate': 1.8267631103074083e-05, 'train/global_step': 1016}


Epoch 1:  40%|████      | 891/2212 [04:48<07:52,  2.79batch/s, loss=2.16, lr=1.83e-5]

{'train/batch_loss': 2.1560802459716797, 'train/learning_rate': 1.8349005424954743e-05, 'train/global_step': 1026}


Epoch 1:  41%|████      | 901/2212 [04:51<07:46,  2.81batch/s, loss=1.66, lr=1.84e-5]

{'train/batch_loss': 1.656196117401123, 'train/learning_rate': 1.843037974683539e-05, 'train/global_step': 1036}


Epoch 1:  41%|████      | 911/2212 [04:55<07:07,  3.04batch/s, loss=1.46, lr=1.85e-5] 

{'train/batch_loss': 1.4575883150100708, 'train/learning_rate': 1.8511754068716045e-05, 'train/global_step': 1046}


Epoch 1:  42%|████▏     | 921/2212 [04:58<06:47,  3.17batch/s, loss=2.15, lr=1.86e-5]

{'train/batch_loss': 2.151710271835327, 'train/learning_rate': 1.859312839059669e-05, 'train/global_step': 1056}


Epoch 1:  42%|████▏     | 931/2212 [05:01<07:07,  3.00batch/s, loss=1.31, lr=1.87e-5]

{'train/batch_loss': 1.3117010593414307, 'train/learning_rate': 1.8674502712477343e-05, 'train/global_step': 1066}


Epoch 1:  43%|████▎     | 941/2212 [05:05<07:22,  2.87batch/s, loss=1.52, lr=1.88e-5]

{'train/batch_loss': 1.520783543586731, 'train/learning_rate': 1.8755877034358e-05, 'train/global_step': 1076}


Epoch 1:  43%|████▎     | 951/2212 [05:08<06:50,  3.07batch/s, loss=1.51, lr=1.88e-5]

{'train/batch_loss': 1.509740948677063, 'train/learning_rate': 1.883725135623865e-05, 'train/global_step': 1086}


Epoch 1:  43%|████▎     | 961/2212 [05:11<05:54,  3.53batch/s, loss=1.7, lr=1.89e-5]  

{'train/batch_loss': 1.700475811958313, 'train/learning_rate': 1.8918625678119304e-05, 'train/global_step': 1096}


Epoch 1:  44%|████▍     | 971/2212 [05:14<06:49,  3.03batch/s, loss=1.41, lr=1.9e-5] 

{'train/batch_loss': 1.4062010049819946, 'train/learning_rate': 1.899999999999995e-05, 'train/global_step': 1106}


Epoch 1:  44%|████▍     | 981/2212 [05:18<09:36,  2.14batch/s, loss=1.4, lr=1.91e-5] 

{'train/batch_loss': 1.4043090343475342, 'train/learning_rate': 1.9081374321880593e-05, 'train/global_step': 1116}


Epoch 1:  45%|████▍     | 991/2212 [05:21<06:03,  3.36batch/s, loss=1.42, lr=1.92e-5]

{'train/batch_loss': 1.4174810647964478, 'train/learning_rate': 1.916274864376124e-05, 'train/global_step': 1126}


Epoch 1:  45%|████▌     | 1001/2212 [05:25<07:07,  2.84batch/s, loss=1.39, lr=1.92e-5]

{'train/batch_loss': 1.3876947164535522, 'train/learning_rate': 1.924412296564189e-05, 'train/global_step': 1136}


Epoch 1:  46%|████▌     | 1011/2212 [05:28<06:29,  3.09batch/s, loss=1.97, lr=1.93e-5]

{'train/batch_loss': 1.965153455734253, 'train/learning_rate': 1.9325497287522544e-05, 'train/global_step': 1146}


Epoch 1:  46%|████▌     | 1021/2212 [05:31<05:40,  3.50batch/s, loss=1.36, lr=1.94e-5]

{'train/batch_loss': 1.3603887557983398, 'train/learning_rate': 1.9406871609403193e-05, 'train/global_step': 1156}


Epoch 1:  47%|████▋     | 1031/2212 [05:34<06:06,  3.23batch/s, loss=1.48, lr=1.95e-5]

{'train/batch_loss': 1.4753010272979736, 'train/learning_rate': 1.948824593128385e-05, 'train/global_step': 1166}


Epoch 1:  47%|████▋     | 1041/2212 [05:37<05:32,  3.52batch/s, loss=1.67, lr=1.96e-5]

{'train/batch_loss': 1.6733242273330688, 'train/learning_rate': 1.95696202531645e-05, 'train/global_step': 1176}


Epoch 1:  48%|████▊     | 1051/2212 [05:40<06:18,  3.07batch/s, loss=1.67, lr=1.97e-5]

{'train/batch_loss': 1.6650854349136353, 'train/learning_rate': 1.9650994575045148e-05, 'train/global_step': 1186}


Epoch 1:  48%|████▊     | 1061/2212 [05:43<05:44,  3.34batch/s, loss=2.04, lr=1.97e-5]

{'train/batch_loss': 2.0379467010498047, 'train/learning_rate': 1.9732368896925797e-05, 'train/global_step': 1196}


Epoch 1:  48%|████▊     | 1071/2212 [05:47<05:48,  3.28batch/s, loss=1.68, lr=1.98e-5]

{'train/batch_loss': 1.6787199974060059, 'train/learning_rate': 1.9813743218806453e-05, 'train/global_step': 1206}


Epoch 1:  49%|████▉     | 1081/2212 [05:50<06:18,  2.98batch/s, loss=1.87, lr=1.99e-5]

{'train/batch_loss': 1.8678960800170898, 'train/learning_rate': 1.9895117540687095e-05, 'train/global_step': 1216}


Epoch 1:  49%|████▉     | 1091/2212 [05:53<06:29,  2.88batch/s, loss=1.86, lr=2e-5]   

{'train/batch_loss': 1.8628824949264526, 'train/learning_rate': 1.9976491862567745e-05, 'train/global_step': 1226}


Epoch 1:  50%|████▉     | 1101/2212 [05:57<07:24,  2.50batch/s, loss=1.87, lr=2.01e-5]

{'train/batch_loss': 1.8710132837295532, 'train/learning_rate': 2.0057866184448394e-05, 'train/global_step': 1236}


Epoch 1:  50%|█████     | 1111/2212 [06:00<05:05,  3.61batch/s, loss=1.89, lr=2.01e-5] 

{'train/batch_loss': 1.8886616230010986, 'train/learning_rate': 2.013924050632904e-05, 'train/global_step': 1246}


Epoch 1:  51%|█████     | 1121/2212 [06:03<05:54,  3.08batch/s, loss=1.95, lr=2.02e-5]

{'train/batch_loss': 1.9458777904510498, 'train/learning_rate': 2.0220614828209686e-05, 'train/global_step': 1256}


Epoch 1:  51%|█████     | 1131/2212 [06:06<05:54,  3.05batch/s, loss=1.18, lr=2.03e-5]

{'train/batch_loss': 1.175872564315796, 'train/learning_rate': 2.0301989150090335e-05, 'train/global_step': 1266}


Epoch 1:  52%|█████▏    | 1141/2212 [06:10<05:29,  3.25batch/s, loss=1.47, lr=2.04e-5]

{'train/batch_loss': 1.469742774963379, 'train/learning_rate': 2.038336347197098e-05, 'train/global_step': 1276}


Epoch 1:  52%|█████▏    | 1151/2212 [06:13<05:55,  2.99batch/s, loss=2.03, lr=2.05e-5]

{'train/batch_loss': 2.0340871810913086, 'train/learning_rate': 2.046473779385163e-05, 'train/global_step': 1286}


Epoch 1:  52%|█████▏    | 1161/2212 [06:16<06:04,  2.88batch/s, loss=1.06, lr=2.05e-5]

{'train/batch_loss': 1.0551247596740723, 'train/learning_rate': 2.0546112115732286e-05, 'train/global_step': 1296}


Epoch 1:  53%|█████▎    | 1171/2212 [06:20<05:40,  3.05batch/s, loss=1.29, lr=2.06e-5] 

{'train/batch_loss': 1.2911087274551392, 'train/learning_rate': 2.0627486437612932e-05, 'train/global_step': 1306}


Epoch 1:  53%|█████▎    | 1181/2212 [06:23<05:08,  3.34batch/s, loss=1.73, lr=2.07e-5]

{'train/batch_loss': 1.7304563522338867, 'train/learning_rate': 2.0708860759493578e-05, 'train/global_step': 1316}


Epoch 1:  54%|█████▍    | 1191/2212 [06:26<04:42,  3.61batch/s, loss=1.44, lr=2.08e-5] 

{'train/batch_loss': 1.4392448663711548, 'train/learning_rate': 2.079023508137423e-05, 'train/global_step': 1326}


Epoch 1:  54%|█████▍    | 1201/2212 [06:29<05:00,  3.36batch/s, loss=1.59, lr=2.09e-5] 

{'train/batch_loss': 1.5934921503067017, 'train/learning_rate': 2.0871609403254883e-05, 'train/global_step': 1336}


Epoch 1:  55%|█████▍    | 1211/2212 [06:32<05:25,  3.08batch/s, loss=1.51, lr=2.1e-5] 

{'train/batch_loss': 1.506731390953064, 'train/learning_rate': 2.095298372513554e-05, 'train/global_step': 1346}


Epoch 1:  55%|█████▌    | 1221/2212 [06:35<05:24,  3.05batch/s, loss=1.22, lr=2.1e-5]

{'train/batch_loss': 1.2158914804458618, 'train/learning_rate': 2.103435804701618e-05, 'train/global_step': 1356}


Epoch 1:  56%|█████▌    | 1231/2212 [06:39<07:01,  2.33batch/s, loss=1.5, lr=2.11e-5]  

{'train/batch_loss': 1.5040562152862549, 'train/learning_rate': 2.1115732368896827e-05, 'train/global_step': 1366}


Epoch 1:  56%|█████▌    | 1241/2212 [06:41<04:36,  3.51batch/s, loss=1.83, lr=2.12e-5]

{'train/batch_loss': 1.8291205167770386, 'train/learning_rate': 2.119710669077748e-05, 'train/global_step': 1376}


Epoch 1:  57%|█████▋    | 1251/2212 [06:45<04:43,  3.39batch/s, loss=1.55, lr=2.13e-5] 

{'train/batch_loss': 1.5490059852600098, 'train/learning_rate': 2.1278481012658136e-05, 'train/global_step': 1386}


Epoch 1:  57%|█████▋    | 1261/2212 [06:48<04:46,  3.32batch/s, loss=1.48, lr=2.14e-5]

{'train/batch_loss': 1.4773626327514648, 'train/learning_rate': 2.135985533453879e-05, 'train/global_step': 1396}


Epoch 1:  57%|█████▋    | 1271/2212 [06:51<05:47,  2.71batch/s, loss=1.38, lr=2.14e-5]

{'train/batch_loss': 1.3817486763000488, 'train/learning_rate': 2.1441229656419434e-05, 'train/global_step': 1406}


Epoch 1:  58%|█████▊    | 1281/2212 [06:54<04:49,  3.21batch/s, loss=1.25, lr=2.15e-5] 

{'train/batch_loss': 1.253340721130371, 'train/learning_rate': 2.152260397830008e-05, 'train/global_step': 1416}


Epoch 1:  58%|█████▊    | 1291/2212 [06:57<04:25,  3.46batch/s, loss=1.19, lr=2.16e-5]

{'train/batch_loss': 1.192317008972168, 'train/learning_rate': 2.1603978300180726e-05, 'train/global_step': 1426}


Epoch 1:  59%|█████▉    | 1301/2212 [07:00<04:33,  3.33batch/s, loss=1.4, lr=2.17e-5] 

{'train/batch_loss': 1.3968641757965088, 'train/learning_rate': 2.168535262206138e-05, 'train/global_step': 1436}


Epoch 1:  59%|█████▉    | 1311/2212 [07:03<04:23,  3.41batch/s, loss=0.995, lr=2.18e-5]

{'train/batch_loss': 0.9950806498527527, 'train/learning_rate': 2.1766726943942035e-05, 'train/global_step': 1446}


Epoch 1:  60%|█████▉    | 1321/2212 [07:07<06:52,  2.16batch/s, loss=1.44, lr=2.18e-5] 

{'train/batch_loss': 1.440327763557434, 'train/learning_rate': 2.1848101265822687e-05, 'train/global_step': 1456}


Epoch 1:  60%|██████    | 1331/2212 [07:10<04:39,  3.15batch/s, loss=1.44, lr=2.19e-5]

{'train/batch_loss': 1.4405014514923096, 'train/learning_rate': 2.1929475587703337e-05, 'train/global_step': 1466}


Epoch 1:  61%|██████    | 1341/2212 [07:13<05:15,  2.76batch/s, loss=1.06, lr=2.2e-5]  

{'train/batch_loss': 1.0604562759399414, 'train/learning_rate': 2.2010849909583993e-05, 'train/global_step': 1476}


Epoch 1:  61%|██████    | 1351/2212 [07:17<04:12,  3.40batch/s, loss=1.44, lr=2.21e-5] 

{'train/batch_loss': 1.4389294385910034, 'train/learning_rate': 2.209222423146465e-05, 'train/global_step': 1486}


Epoch 1:  62%|██████▏   | 1361/2212 [07:20<04:37,  3.06batch/s, loss=0.764, lr=2.22e-5]

{'train/batch_loss': 0.7644575238227844, 'train/learning_rate': 2.21735985533453e-05, 'train/global_step': 1496}


Epoch 1:  62%|██████▏   | 1371/2212 [07:23<04:02,  3.46batch/s, loss=1.55, lr=2.23e-5] 

{'train/batch_loss': 1.5533204078674316, 'train/learning_rate': 2.225497287522596e-05, 'train/global_step': 1506}


Epoch 1:  62%|██████▏   | 1381/2212 [07:26<03:55,  3.53batch/s, loss=1.43, lr=2.23e-5] 

{'train/batch_loss': 1.4299813508987427, 'train/learning_rate': 2.233634719710661e-05, 'train/global_step': 1516}


Epoch 1:  63%|██████▎   | 1391/2212 [07:29<04:10,  3.27batch/s, loss=1.11, lr=2.24e-5] 

{'train/batch_loss': 1.1070713996887207, 'train/learning_rate': 2.241772151898726e-05, 'train/global_step': 1526}


Epoch 1:  63%|██████▎   | 1401/2212 [07:32<03:49,  3.54batch/s, loss=0.918, lr=2.25e-5]

{'train/batch_loss': 0.9175230264663696, 'train/learning_rate': 2.249909584086792e-05, 'train/global_step': 1536}


Epoch 1:  64%|██████▍   | 1411/2212 [07:36<04:24,  3.03batch/s, loss=1.46, lr=2.26e-5] 

{'train/batch_loss': 1.4634599685668945, 'train/learning_rate': 2.2580470162748578e-05, 'train/global_step': 1546}


Epoch 1:  64%|██████▍   | 1421/2212 [07:39<03:51,  3.42batch/s, loss=1.18, lr=2.27e-5] 

{'train/batch_loss': 1.176597237586975, 'train/learning_rate': 2.2661844484629234e-05, 'train/global_step': 1556}


Epoch 1:  65%|██████▍   | 1431/2212 [07:42<04:15,  3.06batch/s, loss=1.02, lr=2.27e-5] 

{'train/batch_loss': 1.0195099115371704, 'train/learning_rate': 2.2743218806509877e-05, 'train/global_step': 1566}


Epoch 1:  65%|██████▌   | 1441/2212 [07:45<03:35,  3.57batch/s, loss=1.42, lr=2.28e-5]

{'train/batch_loss': 1.424550175666809, 'train/learning_rate': 2.282459312839053e-05, 'train/global_step': 1576}


Epoch 1:  66%|██████▌   | 1451/2212 [07:48<04:04,  3.11batch/s, loss=1.04, lr=2.29e-5] 

{'train/batch_loss': 1.0444648265838623, 'train/learning_rate': 2.2905967450271172e-05, 'train/global_step': 1586}


Epoch 1:  66%|██████▌   | 1461/2212 [07:52<05:05,  2.46batch/s, loss=1.27, lr=2.3e-5]  

{'train/batch_loss': 1.2672539949417114, 'train/learning_rate': 2.2987341772151824e-05, 'train/global_step': 1596}


Epoch 1:  67%|██████▋   | 1471/2212 [07:55<04:03,  3.04batch/s, loss=1.2, lr=2.31e-5] 

{'train/batch_loss': 1.1976675987243652, 'train/learning_rate': 2.306871609403247e-05, 'train/global_step': 1606}


Epoch 1:  67%|██████▋   | 1481/2212 [07:59<04:11,  2.91batch/s, loss=1.18, lr=2.32e-5] 

{'train/batch_loss': 1.175409197807312, 'train/learning_rate': 2.3150090415913126e-05, 'train/global_step': 1616}


Epoch 1:  67%|██████▋   | 1491/2212 [08:02<03:49,  3.14batch/s, loss=1.07, lr=2.32e-5] 

{'train/batch_loss': 1.071778416633606, 'train/learning_rate': 2.3231464737793782e-05, 'train/global_step': 1626}


Epoch 1:  68%|██████▊   | 1501/2212 [08:06<03:51,  3.07batch/s, loss=0.973, lr=2.33e-5]

{'train/batch_loss': 0.9734837412834167, 'train/learning_rate': 2.331283905967444e-05, 'train/global_step': 1636}


Epoch 1:  68%|██████▊   | 1511/2212 [08:09<04:25,  2.64batch/s, loss=1.24, lr=2.34e-5] 

{'train/batch_loss': 1.2428432703018188, 'train/learning_rate': 2.3394213381555088e-05, 'train/global_step': 1646}


Epoch 1:  69%|██████▉   | 1521/2212 [08:12<04:07,  2.79batch/s, loss=0.987, lr=2.35e-5]

{'train/batch_loss': 0.987263023853302, 'train/learning_rate': 2.3475587703435737e-05, 'train/global_step': 1656}


Epoch 1:  69%|██████▉   | 1531/2212 [08:15<03:29,  3.25batch/s, loss=1.46, lr=2.36e-5] 

{'train/batch_loss': 1.4557534456253052, 'train/learning_rate': 2.3556962025316396e-05, 'train/global_step': 1666}


Epoch 1:  70%|██████▉   | 1541/2212 [08:18<03:26,  3.26batch/s, loss=1.29, lr=2.36e-5] 

{'train/batch_loss': 1.2862331867218018, 'train/learning_rate': 2.3638336347197045e-05, 'train/global_step': 1676}


Epoch 1:  70%|███████   | 1551/2212 [08:22<03:44,  2.95batch/s, loss=0.64, lr=2.37e-5] 

{'train/batch_loss': 0.6396257281303406, 'train/learning_rate': 2.3719710669077705e-05, 'train/global_step': 1686}


Epoch 1:  71%|███████   | 1561/2212 [08:25<03:09,  3.43batch/s, loss=1.11, lr=2.38e-5] 

{'train/batch_loss': 1.1120820045471191, 'train/learning_rate': 2.3801084990958347e-05, 'train/global_step': 1696}


Epoch 1:  71%|███████   | 1571/2212 [08:28<03:35,  2.98batch/s, loss=1.59, lr=2.39e-5] 

{'train/batch_loss': 1.5942089557647705, 'train/learning_rate': 2.3882459312838993e-05, 'train/global_step': 1706}


Epoch 1:  71%|███████▏  | 1581/2212 [08:31<03:03,  3.43batch/s, loss=1.12, lr=2.4e-5]  

{'train/batch_loss': 1.1209536790847778, 'train/learning_rate': 2.3963833634719636e-05, 'train/global_step': 1716}


Epoch 1:  72%|███████▏  | 1591/2212 [08:34<03:05,  3.34batch/s, loss=0.976, lr=2.4e-5]

{'train/batch_loss': 0.9760858416557312, 'train/learning_rate': 2.4045207956600288e-05, 'train/global_step': 1726}


Epoch 1:  72%|███████▏  | 1601/2212 [08:37<03:13,  3.16batch/s, loss=1.19, lr=2.41e-5]

{'train/batch_loss': 1.1866583824157715, 'train/learning_rate': 2.412658227848094e-05, 'train/global_step': 1736}


Epoch 1:  73%|███████▎  | 1611/2212 [08:40<03:01,  3.31batch/s, loss=1.41, lr=2.42e-5] 

{'train/batch_loss': 1.4066531658172607, 'train/learning_rate': 2.4207956600361587e-05, 'train/global_step': 1746}


Epoch 1:  73%|███████▎  | 1621/2212 [08:43<02:52,  3.43batch/s, loss=0.984, lr=2.43e-5]

{'train/batch_loss': 0.9844028949737549, 'train/learning_rate': 2.4289330922242236e-05, 'train/global_step': 1756}


Epoch 1:  74%|███████▎  | 1631/2212 [08:47<03:53,  2.48batch/s, loss=1.08, lr=2.44e-5] 

{'train/batch_loss': 1.0825088024139404, 'train/learning_rate': 2.4370705244122885e-05, 'train/global_step': 1766}


Epoch 1:  74%|███████▍  | 1641/2212 [08:51<03:58,  2.40batch/s, loss=1.29, lr=2.45e-5] 

{'train/batch_loss': 1.2939893007278442, 'train/learning_rate': 2.445207956600354e-05, 'train/global_step': 1776}


Epoch 1:  75%|███████▍  | 1651/2212 [08:54<03:22,  2.77batch/s, loss=1.02, lr=2.45e-5] 

{'train/batch_loss': 1.015323519706726, 'train/learning_rate': 2.4533453887884187e-05, 'train/global_step': 1786}


Epoch 1:  75%|███████▌  | 1661/2212 [08:58<03:20,  2.75batch/s, loss=1.06, lr=2.46e-5] 

{'train/batch_loss': 1.0578553676605225, 'train/learning_rate': 2.461482820976483e-05, 'train/global_step': 1796}


Epoch 1:  76%|███████▌  | 1671/2212 [09:01<02:39,  3.39batch/s, loss=1.15, lr=2.47e-5] 

{'train/batch_loss': 1.1470848321914673, 'train/learning_rate': 2.469620253164549e-05, 'train/global_step': 1806}


Epoch 1:  76%|███████▌  | 1681/2212 [09:04<03:15,  2.71batch/s, loss=1.32, lr=2.48e-5] 

{'train/batch_loss': 1.319627046585083, 'train/learning_rate': 2.4777576853526138e-05, 'train/global_step': 1816}


Epoch 1:  76%|███████▋  | 1691/2212 [09:07<02:35,  3.34batch/s, loss=0.842, lr=2.49e-5]

{'train/batch_loss': 0.8421593308448792, 'train/learning_rate': 2.4858951175406784e-05, 'train/global_step': 1826}


Epoch 1:  77%|███████▋  | 1701/2212 [09:11<02:25,  3.52batch/s, loss=0.944, lr=2.49e-5]

{'train/batch_loss': 0.9439439177513123, 'train/learning_rate': 2.494032549728743e-05, 'train/global_step': 1836}


Epoch 1:  77%|███████▋  | 1711/2212 [09:13<02:20,  3.56batch/s, loss=0.684, lr=2.5e-5] 

{'train/batch_loss': 0.6840324997901917, 'train/learning_rate': 2.502169981916808e-05, 'train/global_step': 1846}


Epoch 1:  78%|███████▊  | 1721/2212 [09:16<02:13,  3.67batch/s, loss=0.711, lr=2.51e-5]

{'train/batch_loss': 0.7111217379570007, 'train/learning_rate': 2.5103074141048735e-05, 'train/global_step': 1856}


Epoch 1:  78%|███████▊  | 1731/2212 [09:20<02:42,  2.95batch/s, loss=0.898, lr=2.52e-5]

{'train/batch_loss': 0.8984262347221375, 'train/learning_rate': 2.5184448462929388e-05, 'train/global_step': 1866}


Epoch 1:  79%|███████▊  | 1741/2212 [09:24<02:56,  2.67batch/s, loss=0.95, lr=2.53e-5] 

{'train/batch_loss': 0.95046067237854, 'train/learning_rate': 2.5265822784810034e-05, 'train/global_step': 1876}


Epoch 1:  79%|███████▉  | 1751/2212 [09:27<02:26,  3.15batch/s, loss=0.925, lr=2.53e-5]

{'train/batch_loss': 0.9250343441963196, 'train/learning_rate': 2.5347197106690683e-05, 'train/global_step': 1886}


Epoch 1:  80%|███████▉  | 1761/2212 [09:31<02:17,  3.29batch/s, loss=0.99, lr=2.54e-5] 

{'train/batch_loss': 0.990172803401947, 'train/learning_rate': 2.5428571428571332e-05, 'train/global_step': 1896}


Epoch 1:  80%|████████  | 1771/2212 [09:34<02:27,  3.00batch/s, loss=1.27, lr=2.55e-5] 

{'train/batch_loss': 1.2707505226135254, 'train/learning_rate': 2.5509945750451978e-05, 'train/global_step': 1906}


Epoch 1:  81%|████████  | 1781/2212 [09:37<02:06,  3.40batch/s, loss=0.982, lr=2.56e-5]

{'train/batch_loss': 0.9822069406509399, 'train/learning_rate': 2.5591320072332627e-05, 'train/global_step': 1916}


Epoch 1:  81%|████████  | 1791/2212 [09:40<02:00,  3.50batch/s, loss=0.954, lr=2.57e-5]

{'train/batch_loss': 0.9536486268043518, 'train/learning_rate': 2.567269439421328e-05, 'train/global_step': 1926}


Epoch 1:  81%|████████▏ | 1801/2212 [09:44<02:35,  2.65batch/s, loss=1.04, lr=2.58e-5] 

{'train/batch_loss': 1.0443915128707886, 'train/learning_rate': 2.5754068716093926e-05, 'train/global_step': 1936}


Epoch 1:  82%|████████▏ | 1811/2212 [09:46<01:55,  3.49batch/s, loss=0.428, lr=2.58e-5]

{'train/batch_loss': 0.4283481538295746, 'train/learning_rate': 2.5835443037974582e-05, 'train/global_step': 1946}


Epoch 1:  82%|████████▏ | 1821/2212 [09:50<02:11,  2.96batch/s, loss=0.97, lr=2.59e-5] 

{'train/batch_loss': 0.970245361328125, 'train/learning_rate': 2.5916817359855228e-05, 'train/global_step': 1956}


Epoch 1:  83%|████████▎ | 1831/2212 [09:52<01:49,  3.48batch/s, loss=1.11, lr=2.6e-5]  

{'train/batch_loss': 1.108572244644165, 'train/learning_rate': 2.5998191681735887e-05, 'train/global_step': 1966}


Epoch 1:  83%|████████▎ | 1841/2212 [09:56<02:35,  2.39batch/s, loss=0.814, lr=2.61e-5]

{'train/batch_loss': 0.8138083815574646, 'train/learning_rate': 2.607956600361654e-05, 'train/global_step': 1976}


Epoch 1:  84%|████████▎ | 1851/2212 [09:59<01:56,  3.10batch/s, loss=0.855, lr=2.62e-5]

{'train/batch_loss': 0.8549678325653076, 'train/learning_rate': 2.6160940325497192e-05, 'train/global_step': 1986}


Epoch 1:  84%|████████▍ | 1861/2212 [10:03<02:03,  2.84batch/s, loss=0.949, lr=2.62e-5]

{'train/batch_loss': 0.9486470222473145, 'train/learning_rate': 2.6242314647377838e-05, 'train/global_step': 1996}


Epoch 1:  85%|████████▍ | 1871/2212 [10:06<01:42,  3.33batch/s, loss=0.853, lr=2.63e-5]

{'train/batch_loss': 0.853274405002594, 'train/learning_rate': 2.6323688969258488e-05, 'train/global_step': 2006}


Epoch 1:  85%|████████▌ | 1881/2212 [10:10<01:43,  3.20batch/s, loss=0.838, lr=2.64e-5]

{'train/batch_loss': 0.8383336663246155, 'train/learning_rate': 2.6405063291139133e-05, 'train/global_step': 2016}


Epoch 1:  85%|████████▌ | 1891/2212 [10:13<01:38,  3.27batch/s, loss=1.43, lr=2.65e-5] 

{'train/batch_loss': 1.4297350645065308, 'train/learning_rate': 2.6486437613019786e-05, 'train/global_step': 2026}


Epoch 1:  86%|████████▌ | 1901/2212 [10:16<01:37,  3.18batch/s, loss=0.848, lr=2.66e-5]

{'train/batch_loss': 0.8480761647224426, 'train/learning_rate': 2.6567811934900432e-05, 'train/global_step': 2036}


Epoch 1:  86%|████████▋ | 1911/2212 [10:20<01:55,  2.62batch/s, loss=0.523, lr=2.66e-5]

{'train/batch_loss': 0.5229091644287109, 'train/learning_rate': 2.6649186256781074e-05, 'train/global_step': 2046}


Epoch 1:  87%|████████▋ | 1921/2212 [10:23<01:34,  3.10batch/s, loss=0.937, lr=2.67e-5]

{'train/batch_loss': 0.9371833801269531, 'train/learning_rate': 2.6730560578661724e-05, 'train/global_step': 2056}


Epoch 1:  87%|████████▋ | 1931/2212 [10:26<01:27,  3.20batch/s, loss=0.861, lr=2.68e-5]

{'train/batch_loss': 0.8608606457710266, 'train/learning_rate': 2.681193490054237e-05, 'train/global_step': 2066}


Epoch 1:  88%|████████▊ | 1941/2212 [10:29<01:21,  3.31batch/s, loss=1.21, lr=2.69e-5] 

{'train/batch_loss': 1.2059448957443237, 'train/learning_rate': 2.689330922242303e-05, 'train/global_step': 2076}


Epoch 1:  88%|████████▊ | 1951/2212 [10:33<01:32,  2.81batch/s, loss=0.816, lr=2.7e-5] 

{'train/batch_loss': 0.8161307573318481, 'train/learning_rate': 2.6974683544303685e-05, 'train/global_step': 2086}


Epoch 1:  89%|████████▊ | 1961/2212 [10:36<01:21,  3.07batch/s, loss=0.858, lr=2.71e-5]

{'train/batch_loss': 0.8583690524101257, 'train/learning_rate': 2.7056057866184334e-05, 'train/global_step': 2096}


Epoch 1:  89%|████████▉ | 1971/2212 [10:39<01:12,  3.33batch/s, loss=0.824, lr=2.71e-5]

{'train/batch_loss': 0.8241732716560364, 'train/learning_rate': 2.713743218806499e-05, 'train/global_step': 2106}


Epoch 1:  90%|████████▉ | 1981/2212 [10:43<01:21,  2.85batch/s, loss=1.17, lr=2.72e-5] 

{'train/batch_loss': 1.1662240028381348, 'train/learning_rate': 2.7218806509945643e-05, 'train/global_step': 2116}


Epoch 1:  90%|█████████ | 1991/2212 [10:47<01:24,  2.63batch/s, loss=0.803, lr=2.73e-5]

{'train/batch_loss': 0.8028278946876526, 'train/learning_rate': 2.73001808318263e-05, 'train/global_step': 2126}


Epoch 1:  90%|█████████ | 2001/2212 [10:50<01:20,  2.63batch/s, loss=0.877, lr=2.74e-5]

{'train/batch_loss': 0.8774856328964233, 'train/learning_rate': 2.7381555153706958e-05, 'train/global_step': 2136}


Epoch 1:  91%|█████████ | 2011/2212 [10:53<01:08,  2.94batch/s, loss=1.13, lr=2.75e-5] 

{'train/batch_loss': 1.1252366304397583, 'train/learning_rate': 2.7462929475587614e-05, 'train/global_step': 2146}


Epoch 1:  91%|█████████▏| 2021/2212 [10:57<01:19,  2.41batch/s, loss=0.952, lr=2.75e-5]

{'train/batch_loss': 0.9515444040298462, 'train/learning_rate': 2.754430379746826e-05, 'train/global_step': 2156}


Epoch 1:  92%|█████████▏| 2031/2212 [11:00<00:59,  3.02batch/s, loss=1.03, lr=2.76e-5] 

{'train/batch_loss': 1.0320278406143188, 'train/learning_rate': 2.7625678119348913e-05, 'train/global_step': 2166}


Epoch 1:  92%|█████████▏| 2041/2212 [11:03<00:50,  3.40batch/s, loss=1.25, lr=2.77e-5] 

{'train/batch_loss': 1.2548080682754517, 'train/learning_rate': 2.7707052441229572e-05, 'train/global_step': 2176}


Epoch 1:  93%|█████████▎| 2051/2212 [11:06<00:50,  3.19batch/s, loss=0.813, lr=2.78e-5]

{'train/batch_loss': 0.8131533861160278, 'train/learning_rate': 2.7788426763110218e-05, 'train/global_step': 2186}


Epoch 1:  93%|█████████▎| 2061/2212 [11:09<00:49,  3.03batch/s, loss=0.522, lr=2.79e-5]

{'train/batch_loss': 0.5215580463409424, 'train/learning_rate': 2.7869801084990874e-05, 'train/global_step': 2196}


Epoch 1:  94%|█████████▎| 2071/2212 [11:12<00:40,  3.51batch/s, loss=0.844, lr=2.8e-5] 

{'train/batch_loss': 0.8438436985015869, 'train/learning_rate': 2.7951175406871523e-05, 'train/global_step': 2206}


Epoch 1:  94%|█████████▍| 2081/2212 [11:15<00:36,  3.58batch/s, loss=1.02, lr=2.8e-5] 

{'train/batch_loss': 1.0162566900253296, 'train/learning_rate': 2.8032549728752166e-05, 'train/global_step': 2216}


Epoch 1:  95%|█████████▍| 2091/2212 [11:18<00:34,  3.47batch/s, loss=0.837, lr=2.81e-5]

{'train/batch_loss': 0.83709317445755, 'train/learning_rate': 2.811392405063282e-05, 'train/global_step': 2226}


Epoch 1:  95%|█████████▍| 2101/2212 [11:22<00:37,  2.95batch/s, loss=0.696, lr=2.82e-5]

{'train/batch_loss': 0.6957387328147888, 'train/learning_rate': 2.819529837251347e-05, 'train/global_step': 2236}


Epoch 1:  95%|█████████▌| 2111/2212 [11:25<00:38,  2.66batch/s, loss=0.635, lr=2.83e-5]

{'train/batch_loss': 0.6345019340515137, 'train/learning_rate': 2.8276672694394117e-05, 'train/global_step': 2246}


Epoch 1:  96%|█████████▌| 2121/2212 [11:28<00:31,  2.93batch/s, loss=0.689, lr=2.84e-5]

{'train/batch_loss': 0.6885377764701843, 'train/learning_rate': 2.8358047016274763e-05, 'train/global_step': 2256}


Epoch 1:  96%|█████████▋| 2131/2212 [11:31<00:26,  3.02batch/s, loss=0.886, lr=2.84e-5]

{'train/batch_loss': 0.8864436149597168, 'train/learning_rate': 2.8439421338155412e-05, 'train/global_step': 2266}


Epoch 1:  97%|█████████▋| 2141/2212 [11:34<00:20,  3.51batch/s, loss=0.761, lr=2.85e-5]

{'train/batch_loss': 0.7608665227890015, 'train/learning_rate': 2.852079566003606e-05, 'train/global_step': 2276}


Epoch 1:  97%|█████████▋| 2151/2212 [11:37<00:19,  3.15batch/s, loss=0.996, lr=2.86e-5]

{'train/batch_loss': 0.9961159229278564, 'train/learning_rate': 2.860216998191672e-05, 'train/global_step': 2286}


Epoch 1:  98%|█████████▊| 2161/2212 [11:41<00:14,  3.45batch/s, loss=0.619, lr=2.87e-5]

{'train/batch_loss': 0.6193861365318298, 'train/learning_rate': 2.8683544303797367e-05, 'train/global_step': 2296}


Epoch 1:  98%|█████████▊| 2171/2212 [11:44<00:14,  2.92batch/s, loss=0.662, lr=2.88e-5]

{'train/batch_loss': 0.6618478298187256, 'train/learning_rate': 2.8764918625678016e-05, 'train/global_step': 2306}


Epoch 1:  99%|█████████▊| 2181/2212 [11:48<00:11,  2.73batch/s, loss=1.02, lr=2.88e-5] 

{'train/batch_loss': 1.017358660697937, 'train/learning_rate': 2.8846292947558665e-05, 'train/global_step': 2316}


Epoch 1:  99%|█████████▉| 2191/2212 [11:51<00:05,  3.50batch/s, loss=0.871, lr=2.89e-5]

{'train/batch_loss': 0.871448814868927, 'train/learning_rate': 2.8927667269439324e-05, 'train/global_step': 2326}


Epoch 1: 100%|█████████▉| 2201/2212 [11:54<00:03,  2.97batch/s, loss=0.434, lr=2.9e-5] 

{'train/batch_loss': 0.4335782527923584, 'train/learning_rate': 2.9009041591319984e-05, 'train/global_step': 2336}


Epoch 1: 100%|█████████▉| 2211/2212 [11:57<00:00,  3.54batch/s, loss=0.91, lr=2.91e-5] 

{'train/batch_loss': 0.9099376797676086, 'train/learning_rate': 2.9090415913200643e-05, 'train/global_step': 2346}


Epoch 1: 100%|██████████| 2212/2212 [11:57<00:00,  3.08batch/s, loss=1.41, lr=2.91e-5]


In [33]:
batch['ehr'].shape, batch['prev_cxr'].shape

(torch.Size([2, 100, 81]), torch.Size([2, 100, 512]))

In [35]:
torch.cat([batch['ehr'],  batch['prev_cxr']], dim = -1).shape

torch.Size([2, 100, 593])

In [31]:
inf_mask = torch.isinf(batch['ehr'])
inf_indices = torch.where(inf_mask)
num_infs = inf_mask.sum().item()

if inf_mask.any():
    # For a multi-dimensional tensor
    for idx in zip(*torch.where(inf_mask)):
        print(f"Inf found at index: {idx}")

Inf found at index: (tensor(3), tensor(64), tensor(46))
Inf found at index: (tensor(3), tensor(65), tensor(46))
Inf found at index: (tensor(3), tensor(66), tensor(46))
Inf found at index: (tensor(3), tensor(67), tensor(46))
Inf found at index: (tensor(3), tensor(68), tensor(46))
Inf found at index: (tensor(3), tensor(69), tensor(46))
Inf found at index: (tensor(3), tensor(70), tensor(46))
Inf found at index: (tensor(3), tensor(71), tensor(46))
Inf found at index: (tensor(3), tensor(72), tensor(46))
Inf found at index: (tensor(3), tensor(73), tensor(46))
Inf found at index: (tensor(3), tensor(74), tensor(46))
Inf found at index: (tensor(3), tensor(75), tensor(46))
Inf found at index: (tensor(3), tensor(76), tensor(46))
Inf found at index: (tensor(3), tensor(77), tensor(46))
Inf found at index: (tensor(3), tensor(78), tensor(46))
Inf found at index: (tensor(3), tensor(79), tensor(46))
Inf found at index: (tensor(3), tensor(80), tensor(46))
Inf found at index: (tensor(3), tensor(81), tens

In [None]:
batch['ehr'][1, :, 45]

In [36]:
df = pd.read_pickle(supertable_path / (batch_encounters[1] + '.pickle'))

In [43]:
df.values[:, 45]

array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, -0.5018181818181818, -0.5018181818181818, -0.505,
       -0.505, -0.505, -0.505, -0.505, -0.505, -0.505, -0.25, -0.25,
       -0.25, -0.25, -0.5, -0.5, -0.2515625, -0.2515625, -0.2515625,
       -0.2515625, -0.2550000000000001, -0.2550000000000001,
       -0.2550000000000001, -0.2575, -0.253125, -0.253125, -0.253125,
       -0.253125, -0.253125, -0.2625, -0.2625, -0.2625, -0.2625,
       -0.2666666666666667, -0.2666666666666667, -0.2666666666666667,
       -0.26875, -0.2583333333333334, -0.2583333333333334,
       -0.2583333333333334, -0.2583333333333334, -0.2583333333333334,
       -0.2625, -0.2625, -0.2625, -0.2666666666666667,
       -0.2666666666666667, -0.2666666666666667, -0.275, -0.26875,
       -0.26875, -0.26875, -0.26875, -0.2625, -0.2625, -0.2625, -0.2625,
       -0.25, -0.25, -0.2583333333333334, -0.25833333333

In [30]:
outputs = model(
                    ehr=batch["ehr"],
                    prev_cxr=batch["prev_cxr"],
                    target_input=None,
                    encoder_attention_mask=batch.get("attention_mask"),
                    decoder_attention_mask=batch.get("attention_mask"),
                    causal_mask=True, debug = True, batch_idx_debug = 1
                )

EHR input
tensor([[0.9000, 1.8933, 1.2533,  ..., 0.0000, 1.0000, 0.0000],
        [1.0100, 1.8933, 1.7400,  ..., 0.0000, 1.0000, 0.0000],
        [1.0500, 1.8933, 1.8000,  ..., 0.0000, 1.0000, 0.0000],
        ...,
        [0.5000, 1.7533, 1.4000,  ..., 0.0000, 1.0000, 0.0000],
        [0.5000, 1.7533, 1.4000,  ..., 0.0000, 1.0000, 0.0000],
        [0.5000, 1.7533, 1.6000,  ..., 0.0000, 1.0000, 0.0000]])
ehr - min: -3.4000000953674316, max: inf, has_nan: False
EHR embedding
tensor([[ 0.3397, -2.5103, -1.9462,  ..., -0.2305, -1.0521, -0.8516],
        [ 0.3448, -2.5688, -1.9900,  ..., -0.2090, -1.0538, -0.8895],
        [ 0.3552, -2.6524, -2.0044,  ..., -0.2060, -1.0753, -0.9462],
        ...,
        [-0.9441, -4.9256, -2.1909,  ...,  2.2478, -3.2096, -0.9323],
        [-0.9532, -4.9645, -2.2049,  ...,  2.2638, -3.2335, -0.9412],
        [-0.9052, -5.0178, -2.2416,  ...,  2.2919, -3.2731, -0.9600]],
       grad_fn=<SliceBackward0>)
x - min: -inf, max: inf, has_nan: False
CXR condition
