In [1]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import os
import pickle
import itertools
import time

import wandb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, RandomSampler, random_split

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorWithPadding

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
BASE_MODEL = "deepseek-ai/deepseek-coder-1.3b-base"
STORAGE_DIR = "/proj/rcs-hdd/aj3051/symmetry"

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# with open(os.path.join(STORAGE_DIR, 'data_four_permutations.pickle'), 'rb') as f:
#     dataset_raw = pickle.load(f)

In [5]:
# def tokenize_dataset_individual(dataset, max_length=1024, tokenizer="deepseek-ai/deepseek-coder-1.3b-base"):
#     tokenizer = AutoTokenizer.from_pretrained(tokenizer, cache_dir="/proj/rcs-hdd/aj3051/hf_transformers")
#     res = []
#     for code, line_permutation_orders, label in tqdm(dataset):
#         tokenized_loc = [ 
#             tokenizer(line_text, return_tensors="pt", add_special_tokens=(line_num==0))['input_ids'][0]
#             for line_num, line_text in enumerate(code)
#         ]
#         input = torch.cat(tokenized_loc) # unpermuted code input
#         if len(input) > max_length:
#             continue 

#         label = tokenizer(label, return_tensors="pt", add_special_tokens=False)['input_ids'][0][0].unsqueeze(dim=0)
#         sample = (input.to(dtype=torch.long), label)
#         res.append(sample)

#     return res

# dataset = tokenize_dataset_individual(dataset=dataset_raw, max_length=1024, tokenizer="deepseek-ai/deepseek-coder-1.3b-base")

# with open(os.path.join(STORAGE_DIR, 'data_individual.pkl'), 'wb') as f:
#     pickle.dump(dataset, f)

In [4]:
with open(os.path.join(STORAGE_DIR, 'data_individual.pkl'), 'rb') as f:
    code_individual_dataset = pickle.load(f)

In [5]:
class CodeDataset(Dataset):

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        input_ids, label = self.dataset[idx] 
        return {
            'input_ids': input_ids
            'labels': torch.tensor([label])
        }


In [6]:
dataset = CodeDataset(dataset=code_individual_dataset)

sampler = RandomSampler(dataset)
shuffled_indices = list(sampler)
shuffled_dataset = torch.utils.data.Subset(dataset, shuffled_indices)
train_length = int(0.8 * len(shuffled_dataset))
validation_length = int(0.1 * len(shuffled_dataset))
test_length = len(shuffled_dataset) - train_length - validation_length

train_set, validation_set, test_set = random_split(
    shuffled_dataset, [train_length, validation_length, test_length]
)

# train_loader = torch.utils.data.DataLoader(train_set, batch_size=None, shuffle=False)
# validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=None, shuffle=False)
# test_loader = torch.utils.data.DataLoader(test_set, batch_size=None, shuffle=False)

In [7]:
model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        cache_dir="/proj/rcs-hdd/aj3051/hf_transformers",
).to(device)

In [8]:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, cache_dir="/proj/rcs-hdd/aj3051/hf_transformers")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
training_args = TrainingArguments(
    output_dir='./results',         
    num_train_epochs=3,             
    per_device_train_batch_size=1,  
    per_device_eval_batch_size=1, 
    report_to=None, 
)

In [10]:
first_batch = None

In [11]:
class CustomTrainer(Trainer):
    def training_step(self, model, inputs):
        print("Input IDs shape:", inputs["input_ids"].shape)
        print("Labels shape:", inputs["labels"].shape)
        global first_batch
        first_batch = inputs["input_ids"]
        # ... rest of the code ...
        return super().training_step(model, inputs)        

In [13]:
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=validation_set,
    data_collator=data_collator,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [14]:
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33maj3051[0m. Use [1m`wandb login --relogin`[0m to force relogin


Input IDs shape: torch.Size([1, 145])
Labels shape: torch.Size([1, 1])


ValueError: Expected input batch_size (144) to match target batch_size (0).

In [68]:
first_batch[0]

tensor([32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014, 32014,
        32014, 32014, 32014, 32013, 13198,  7050, 15545,  8440,   821,  6337,
          334,  1232,  1753, 15545,  8440,  9828,  2039,  1232, 

In [10]:
lr = 1e-5
num_epochs = 5

In [11]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

eval_every = 3000
checkpoint_every = 2000


model.train()
step = 0
for epoch in range(num_epochs):  # Train for 5 epochs
    total_loss = 0
    # for i, batch in enumerate(tqdm(dataloader)):
    for epoch_i, batch in tqdm(enumerate(train_loader)):
        input_ids, label = batch
        input_ids = input_ids.unsqueeze(dim=0).to(device)
        label = label.to(device)
        # print(longest_len)
        optimizer.zero_grad()
        print(f"input shape: {input_ids.shape}")
        output = model(input_ids)
        logits = output.logits[:, -1, :]

        loss = loss_fn(logits, label)
        loss.backward()
        optimizer.step()

        step += 1
        # if step % eval_every == 0:
        #     val_loss = eval()
        #     wandb.log({"validation loss: ", val_loss})
        # if step % checkpoint_every == 0:
        #     save_dir = f"{MODEL_DIR}/model_{step}"
        #     torch.save({
        #         'epoch': epoch, 
        #         'model_state_dict': equitune_model.state_dict(),
        #         'optimizer_state_dict': optimizer.state_dict(),
        #         'loss': loss
            # }, save_dir)
        total_loss += loss.item()
        print(f'loss: {loss}')
        # wandb.log({"loss": loss.item()})

0it [00:00, ?it/s]

input shape: torch.Size([1, 173])


2it [00:01,  1.55it/s]

loss: 8.689229011535645
input shape: torch.Size([1, 567])
loss: nan
input shape: torch.Size([1, 205])


4it [00:01,  3.20it/s]

loss: nan
input shape: torch.Size([1, 531])
loss: nan
input shape: torch.Size([1, 109])


6it [00:02,  4.55it/s]

loss: nan
input shape: torch.Size([1, 528])
loss: nan
input shape: torch.Size([1, 173])


8it [00:02,  5.72it/s]

loss: nan
input shape: torch.Size([1, 288])
loss: nan
input shape: torch.Size([1, 158])


10it [00:02,  6.10it/s]

loss: nan
input shape: torch.Size([1, 614])
loss: nan
input shape: torch.Size([1, 117])


12it [00:02,  6.62it/s]

loss: nan
input shape: torch.Size([1, 356])
loss: nan
input shape: torch.Size([1, 206])


14it [00:03,  6.80it/s]

loss: nan
input shape: torch.Size([1, 223])
loss: nan
input shape: torch.Size([1, 507])


16it [00:03,  6.85it/s]

loss: nan
input shape: torch.Size([1, 350])
loss: nan
input shape: torch.Size([1, 152])


18it [00:03,  6.97it/s]

loss: nan
input shape: torch.Size([1, 379])
loss: nan
input shape: torch.Size([1, 268])


20it [00:04,  7.14it/s]

loss: nan
input shape: torch.Size([1, 143])
loss: nan
input shape: torch.Size([1, 190])


22it [00:04,  7.04it/s]

loss: nan
input shape: torch.Size([1, 423])
loss: nan
input shape: torch.Size([1, 285])


24it [00:04,  7.14it/s]

loss: nan
input shape: torch.Size([1, 342])
loss: nan
input shape: torch.Size([1, 516])


26it [00:04,  7.00it/s]

loss: nan
input shape: torch.Size([1, 233])
loss: nan
input shape: torch.Size([1, 189])


28it [00:05,  7.07it/s]

loss: nan
input shape: torch.Size([1, 374])
loss: nan
input shape: torch.Size([1, 255])


30it [00:05,  7.13it/s]

loss: nan
input shape: torch.Size([1, 232])
loss: nan
input shape: torch.Size([1, 457])


32it [00:05,  6.91it/s]

loss: nan
input shape: torch.Size([1, 376])
loss: nan
input shape: torch.Size([1, 388])


34it [00:06,  6.65it/s]

loss: nan
input shape: torch.Size([1, 109])
loss: nan
input shape: torch.Size([1, 519])


36it [00:06,  6.44it/s]

loss: nan
input shape: torch.Size([1, 421])
loss: nan
input shape: torch.Size([1, 580])


38it [00:06,  6.46it/s]

loss: nan
input shape: torch.Size([1, 308])
loss: nan
input shape: torch.Size([1, 292])


40it [00:07,  6.53it/s]

loss: nan
input shape: torch.Size([1, 82])
loss: nan
input shape: torch.Size([1, 409])


42it [00:07,  6.56it/s]

loss: nan
input shape: torch.Size([1, 324])
loss: nan
input shape: torch.Size([1, 693])


44it [00:07,  6.19it/s]

loss: nan
input shape: torch.Size([1, 282])
loss: nan
input shape: torch.Size([1, 152])


46it [00:08,  5.98it/s]

loss: nan
input shape: torch.Size([1, 677])
loss: nan
input shape: torch.Size([1, 364])


48it [00:08,  6.23it/s]

loss: nan
input shape: torch.Size([1, 304])
loss: nan
input shape: torch.Size([1, 208])


50it [00:08,  6.45it/s]

loss: nan
input shape: torch.Size([1, 141])
loss: nan
input shape: torch.Size([1, 519])


52it [00:08,  6.48it/s]

loss: nan
input shape: torch.Size([1, 354])
loss: nan
input shape: torch.Size([1, 261])


54it [00:09,  6.48it/s]

loss: nan
input shape: torch.Size([1, 305])
loss: nan
input shape: torch.Size([1, 141])


56it [00:09,  6.51it/s]

loss: nan
input shape: torch.Size([1, 485])
loss: nan
input shape: torch.Size([1, 218])


58it [00:09,  6.61it/s]

loss: nan
input shape: torch.Size([1, 109])
loss: nan
input shape: torch.Size([1, 218])


60it [00:10,  6.61it/s]

loss: nan
input shape: torch.Size([1, 314])
loss: nan
input shape: torch.Size([1, 415])


62it [00:10,  6.58it/s]

loss: nan
input shape: torch.Size([1, 384])
loss: nan
input shape: torch.Size([1, 180])


63it [00:10,  6.62it/s]

loss: nan
input shape: torch.Size([1, 38])


65it [00:11,  3.62it/s]

loss: nan
input shape: torch.Size([1, 162])
loss: nan
input shape: torch.Size([1, 401])


67it [00:11,  4.69it/s]

loss: nan
input shape: torch.Size([1, 182])
loss: nan
input shape: torch.Size([1, 293])


69it [00:12,  5.52it/s]

loss: nan
input shape: torch.Size([1, 145])
loss: nan
input shape: torch.Size([1, 289])


71it [00:12,  5.66it/s]

loss: nan
input shape: torch.Size([1, 702])
loss: nan
input shape: torch.Size([1, 259])


73it [00:12,  6.10it/s]

loss: nan
input shape: torch.Size([1, 260])
loss: nan
input shape: torch.Size([1, 367])


75it [00:13,  6.22it/s]

loss: nan
input shape: torch.Size([1, 486])
loss: nan
input shape: torch.Size([1, 358])


77it [00:13,  6.40it/s]

loss: nan
input shape: torch.Size([1, 382])
loss: nan
input shape: torch.Size([1, 429])


79it [00:13,  6.46it/s]

loss: nan
input shape: torch.Size([1, 225])
loss: nan
input shape: torch.Size([1, 210])


81it [00:13,  6.50it/s]

loss: nan
input shape: torch.Size([1, 294])
loss: nan
input shape: torch.Size([1, 331])


83it [00:14,  6.54it/s]

loss: nan
input shape: torch.Size([1, 216])
loss: nan
input shape: torch.Size([1, 242])


85it [00:14,  6.57it/s]

loss: nan
input shape: torch.Size([1, 372])
loss: nan
input shape: torch.Size([1, 117])


87it [00:14,  6.58it/s]

loss: nan
input shape: torch.Size([1, 94])
loss: nan
input shape: torch.Size([1, 236])


89it [00:15,  6.65it/s]

loss: nan
input shape: torch.Size([1, 243])
loss: nan
input shape: torch.Size([1, 88])


91it [00:15,  6.61it/s]

loss: nan
input shape: torch.Size([1, 502])
loss: nan
input shape: torch.Size([1, 386])


93it [00:15,  6.57it/s]

loss: nan
input shape: torch.Size([1, 282])
loss: nan
input shape: torch.Size([1, 210])


95it [00:16,  6.73it/s]

loss: nan
input shape: torch.Size([1, 353])
loss: nan
input shape: torch.Size([1, 325])


97it [00:16,  6.11it/s]

loss: nan
input shape: torch.Size([1, 725])
loss: nan
input shape: torch.Size([1, 286])


99it [00:16,  6.04it/s]

loss: nan
input shape: torch.Size([1, 658])
loss: nan
input shape: torch.Size([1, 146])


101it [00:17,  6.26it/s]

loss: nan
input shape: torch.Size([1, 485])
loss: nan
input shape: torch.Size([1, 523])


103it [00:17,  6.49it/s]

loss: nan
input shape: torch.Size([1, 382])
loss: nan
input shape: torch.Size([1, 71])


105it [00:17,  6.69it/s]

loss: nan
input shape: torch.Size([1, 235])
loss: nan
input shape: torch.Size([1, 408])


107it [00:17,  6.69it/s]

loss: nan
input shape: torch.Size([1, 173])
loss: nan
input shape: torch.Size([1, 187])


109it [00:18,  6.44it/s]

loss: nan
input shape: torch.Size([1, 551])
loss: nan
input shape: torch.Size([1, 497])


111it [00:18,  6.46it/s]

loss: nan
input shape: torch.Size([1, 170])
loss: nan
input shape: torch.Size([1, 145])


113it [00:18,  6.64it/s]

loss: nan
input shape: torch.Size([1, 239])
loss: nan
input shape: torch.Size([1, 243])


115it [00:19,  6.59it/s]

loss: nan
input shape: torch.Size([1, 469])
loss: nan
input shape: torch.Size([1, 155])


117it [00:19,  6.67it/s]

loss: nan
input shape: torch.Size([1, 170])
loss: nan
input shape: torch.Size([1, 213])


119it [00:19,  6.50it/s]

loss: nan
input shape: torch.Size([1, 502])
loss: nan
input shape: torch.Size([1, 243])


121it [00:20,  6.46it/s]

loss: nan
input shape: torch.Size([1, 492])
loss: nan
input shape: torch.Size([1, 100])


123it [00:20,  6.00it/s]

loss: nan
input shape: torch.Size([1, 760])
loss: nan
input shape: torch.Size([1, 136])


125it [00:20,  6.33it/s]

loss: nan
input shape: torch.Size([1, 305])
loss: nan
input shape: torch.Size([1, 142])


127it [00:21,  6.52it/s]

loss: nan
input shape: torch.Size([1, 226])
loss: nan
input shape: torch.Size([1, 179])


129it [00:21,  6.66it/s]

loss: nan
input shape: torch.Size([1, 121])
loss: nan
input shape: torch.Size([1, 311])


131it [00:21,  6.68it/s]

loss: nan
input shape: torch.Size([1, 235])
loss: nan
input shape: torch.Size([1, 433])


133it [00:21,  6.63it/s]

loss: nan
input shape: torch.Size([1, 371])
loss: nan
input shape: torch.Size([1, 562])


135it [00:22,  6.62it/s]

loss: nan
input shape: torch.Size([1, 253])
loss: nan
input shape: torch.Size([1, 236])


137it [00:22,  6.62it/s]

loss: nan
input shape: torch.Size([1, 494])
loss: nan
input shape: torch.Size([1, 204])


139it [00:22,  6.57it/s]

loss: nan
input shape: torch.Size([1, 242])
loss: nan
input shape: torch.Size([1, 301])


141it [00:23,  6.65it/s]

loss: nan
input shape: torch.Size([1, 338])
loss: nan
input shape: torch.Size([1, 358])


143it [00:23,  6.53it/s]

loss: nan
input shape: torch.Size([1, 412])
loss: nan
input shape: torch.Size([1, 245])


145it [00:23,  6.63it/s]

loss: nan
input shape: torch.Size([1, 337])
loss: nan
input shape: torch.Size([1, 288])


147it [00:24,  6.87it/s]

loss: nan
input shape: torch.Size([1, 297])
loss: nan
input shape: torch.Size([1, 341])


149it [00:24,  7.04it/s]

loss: nan
input shape: torch.Size([1, 162])
loss: nan
input shape: torch.Size([1, 379])


151it [00:24,  6.96it/s]

loss: nan
input shape: torch.Size([1, 255])
loss: nan
input shape: torch.Size([1, 355])


153it [00:24,  6.92it/s]

loss: nan
input shape: torch.Size([1, 416])
loss: nan
input shape: torch.Size([1, 254])


155it [00:25,  7.00it/s]

loss: nan
input shape: torch.Size([1, 241])
loss: nan
input shape: torch.Size([1, 247])


157it [00:25,  7.03it/s]

loss: nan
input shape: torch.Size([1, 227])
loss: nan
input shape: torch.Size([1, 460])


159it [00:25,  6.96it/s]

loss: nan
input shape: torch.Size([1, 327])
loss: nan
input shape: torch.Size([1, 304])


161it [00:26,  7.09it/s]

loss: nan
input shape: torch.Size([1, 203])
loss: nan
input shape: torch.Size([1, 610])


163it [00:26,  6.64it/s]

loss: nan
input shape: torch.Size([1, 399])
loss: nan
input shape: torch.Size([1, 488])


165it [00:26,  6.40it/s]

loss: nan
input shape: torch.Size([1, 514])
loss: nan
input shape: torch.Size([1, 697])


167it [00:27,  6.37it/s]

loss: nan
input shape: torch.Size([1, 148])
loss: nan
input shape: torch.Size([1, 245])


169it [00:27,  6.74it/s]

loss: nan
input shape: torch.Size([1, 339])
loss: nan
input shape: torch.Size([1, 259])


171it [00:27,  6.99it/s]

loss: nan
input shape: torch.Size([1, 114])
loss: nan
input shape: torch.Size([1, 347])


173it [00:27,  6.85it/s]

loss: nan
input shape: torch.Size([1, 482])
loss: nan
input shape: torch.Size([1, 134])


175it [00:28,  7.03it/s]

loss: nan
input shape: torch.Size([1, 263])
loss: nan
input shape: torch.Size([1, 160])


177it [00:28,  7.07it/s]

loss: nan
input shape: torch.Size([1, 364])
loss: nan
input shape: torch.Size([1, 396])


179it [00:28,  6.97it/s]

loss: nan
input shape: torch.Size([1, 191])
loss: nan
input shape: torch.Size([1, 522])


181it [00:29,  6.91it/s]

loss: nan
input shape: torch.Size([1, 155])
loss: nan
input shape: torch.Size([1, 109])


183it [00:29,  6.99it/s]

loss: nan
input shape: torch.Size([1, 330])
loss: nan
input shape: torch.Size([1, 327])


185it [00:29,  7.12it/s]

loss: nan
input shape: torch.Size([1, 256])
loss: nan
input shape: torch.Size([1, 547])


187it [00:29,  6.86it/s]

loss: nan
input shape: torch.Size([1, 87])
loss: nan
input shape: torch.Size([1, 371])


189it [00:30,  7.01it/s]

loss: nan
input shape: torch.Size([1, 254])
loss: nan
input shape: torch.Size([1, 109])


191it [00:30,  7.10it/s]

loss: nan
input shape: torch.Size([1, 219])
loss: nan
input shape: torch.Size([1, 468])


193it [00:30,  6.97it/s]

loss: nan
input shape: torch.Size([1, 120])
loss: nan
input shape: torch.Size([1, 205])


195it [00:31,  7.10it/s]

loss: nan
input shape: torch.Size([1, 324])
loss: nan
input shape: torch.Size([1, 686])


197it [00:31,  6.70it/s]

loss: nan
input shape: torch.Size([1, 302])
loss: nan
input shape: torch.Size([1, 783])


199it [00:31,  6.12it/s]

loss: nan
input shape: torch.Size([1, 389])
loss: nan
input shape: torch.Size([1, 561])


201it [00:32,  5.78it/s]

loss: nan
input shape: torch.Size([1, 722])
loss: nan
input shape: torch.Size([1, 370])


203it [00:32,  6.41it/s]

loss: nan
input shape: torch.Size([1, 305])
loss: nan
input shape: torch.Size([1, 553])


205it [00:32,  6.58it/s]

loss: nan
input shape: torch.Size([1, 37])
loss: nan
input shape: torch.Size([1, 478])


207it [00:32,  6.75it/s]

loss: nan
input shape: torch.Size([1, 339])
loss: nan
input shape: torch.Size([1, 217])


209it [00:33,  7.01it/s]

loss: nan
input shape: torch.Size([1, 277])
loss: nan
input shape: torch.Size([1, 209])


211it [00:33,  6.75it/s]

loss: nan
input shape: torch.Size([1, 577])
loss: nan
input shape: torch.Size([1, 435])


213it [00:33,  6.85it/s]

loss: nan
input shape: torch.Size([1, 261])
loss: nan
input shape: torch.Size([1, 242])


215it [00:34,  7.05it/s]

loss: nan
input shape: torch.Size([1, 154])
loss: nan
input shape: torch.Size([1, 172])


217it [00:34,  6.93it/s]

loss: nan
input shape: torch.Size([1, 444])
loss: nan
input shape: torch.Size([1, 317])


219it [00:34,  7.03it/s]

loss: nan
input shape: torch.Size([1, 356])
loss: nan
input shape: torch.Size([1, 382])


221it [00:34,  7.07it/s]

loss: nan
input shape: torch.Size([1, 373])
loss: nan
input shape: torch.Size([1, 291])


223it [00:35,  7.14it/s]

loss: nan
input shape: torch.Size([1, 193])
loss: nan
input shape: torch.Size([1, 431])


225it [00:35,  6.45it/s]

loss: nan
input shape: torch.Size([1, 670])
loss: nan
input shape: torch.Size([1, 424])


227it [00:35,  6.74it/s]

loss: nan
input shape: torch.Size([1, 158])
loss: nan
input shape: torch.Size([1, 236])


229it [00:36,  6.97it/s]

loss: nan
input shape: torch.Size([1, 345])
loss: nan
input shape: torch.Size([1, 410])


231it [00:36,  6.75it/s]

loss: nan
input shape: torch.Size([1, 505])
loss: nan
input shape: torch.Size([1, 239])


233it [00:36,  7.00it/s]

loss: nan
input shape: torch.Size([1, 250])
loss: nan
input shape: torch.Size([1, 93])


235it [00:36,  7.08it/s]

loss: nan
input shape: torch.Size([1, 367])
loss: nan
input shape: torch.Size([1, 285])


237it [00:37,  6.90it/s]

loss: nan
input shape: torch.Size([1, 501])
loss: nan
input shape: torch.Size([1, 609])


239it [00:37,  6.48it/s]

loss: nan
input shape: torch.Size([1, 558])
loss: nan
input shape: torch.Size([1, 229])


241it [00:37,  6.57it/s]

loss: nan
input shape: torch.Size([1, 528])
loss: nan
input shape: torch.Size([1, 216])


243it [00:38,  6.84it/s]

loss: nan
input shape: torch.Size([1, 229])
loss: nan
input shape: torch.Size([1, 212])


245it [00:38,  6.93it/s]

loss: nan
input shape: torch.Size([1, 178])
loss: nan
input shape: torch.Size([1, 338])


247it [00:38,  7.07it/s]

loss: nan
input shape: torch.Size([1, 326])
loss: nan
input shape: torch.Size([1, 134])


249it [00:39,  6.84it/s]

loss: nan
input shape: torch.Size([1, 453])
loss: nan
input shape: torch.Size([1, 485])


251it [00:39,  6.50it/s]

loss: nan
input shape: torch.Size([1, 387])
loss: nan
input shape: torch.Size([1, 205])


253it [00:39,  6.42it/s]

loss: nan
input shape: torch.Size([1, 476])
loss: nan
input shape: torch.Size([1, 481])


255it [00:39,  6.51it/s]

loss: nan
input shape: torch.Size([1, 147])
loss: nan
input shape: torch.Size([1, 516])


257it [00:40,  6.50it/s]

loss: nan
input shape: torch.Size([1, 292])
loss: nan
input shape: torch.Size([1, 467])


259it [00:40,  6.50it/s]

loss: nan
input shape: torch.Size([1, 155])
loss: nan
input shape: torch.Size([1, 285])


261it [00:40,  6.40it/s]

loss: nan
input shape: torch.Size([1, 520])
loss: nan
input shape: torch.Size([1, 286])


263it [00:41,  6.53it/s]

loss: nan
input shape: torch.Size([1, 306])
loss: nan
input shape: torch.Size([1, 431])


265it [00:41,  6.52it/s]

loss: nan
input shape: torch.Size([1, 419])
loss: nan
input shape: torch.Size([1, 204])


267it [00:41,  6.49it/s]

loss: nan
input shape: torch.Size([1, 260])
loss: nan
input shape: torch.Size([1, 185])


269it [00:42,  6.50it/s]

loss: nan
input shape: torch.Size([1, 390])
loss: nan
input shape: torch.Size([1, 556])


271it [00:42,  6.47it/s]

loss: nan
input shape: torch.Size([1, 164])
loss: nan
input shape: torch.Size([1, 152])


273it [00:42,  6.57it/s]

loss: nan
input shape: torch.Size([1, 148])
loss: nan
input shape: torch.Size([1, 92])


275it [00:43,  6.42it/s]

loss: nan
input shape: torch.Size([1, 540])
loss: nan
input shape: torch.Size([1, 364])


277it [00:43,  6.60it/s]

loss: nan
input shape: torch.Size([1, 322])
loss: nan
input shape: torch.Size([1, 197])


279it [00:43,  6.40it/s]

loss: nan
input shape: torch.Size([1, 584])
loss: nan
input shape: torch.Size([1, 478])


281it [00:43,  6.35it/s]

loss: nan
input shape: torch.Size([1, 546])
loss: nan
input shape: torch.Size([1, 198])


283it [00:44,  6.12it/s]

loss: nan
input shape: torch.Size([1, 658])
loss: nan
input shape: torch.Size([1, 381])


285it [00:44,  6.34it/s]

loss: nan
input shape: torch.Size([1, 417])
loss: nan
input shape: torch.Size([1, 147])


287it [00:44,  6.46it/s]

loss: nan
input shape: torch.Size([1, 212])
loss: nan
input shape: torch.Size([1, 579])


289it [00:45,  6.26it/s]

loss: nan
input shape: torch.Size([1, 582])
loss: nan
input shape: torch.Size([1, 261])


291it [00:45,  6.45it/s]

loss: nan
input shape: torch.Size([1, 248])
loss: nan
input shape: torch.Size([1, 306])


293it [00:45,  6.59it/s]

loss: nan
input shape: torch.Size([1, 220])
loss: nan
input shape: torch.Size([1, 405])


295it [00:46,  6.69it/s]

loss: nan
input shape: torch.Size([1, 137])
loss: nan
input shape: torch.Size([1, 134])


297it [00:46,  6.70it/s]

loss: nan
input shape: torch.Size([1, 392])
loss: nan
input shape: torch.Size([1, 209])


299it [00:46,  6.76it/s]

loss: nan
input shape: torch.Size([1, 313])
loss: nan
input shape: torch.Size([1, 362])


301it [00:47,  6.86it/s]

loss: nan
input shape: torch.Size([1, 148])
loss: nan
input shape: torch.Size([1, 457])


303it [00:47,  6.82it/s]

loss: nan
input shape: torch.Size([1, 137])
loss: nan
input shape: torch.Size([1, 58])


305it [00:47,  6.73it/s]

loss: nan
input shape: torch.Size([1, 38])
loss: nan
input shape: torch.Size([1, 422])


307it [00:47,  6.85it/s]

loss: nan
input shape: torch.Size([1, 148])
loss: nan
input shape: torch.Size([1, 135])


309it [00:48,  6.74it/s]

loss: nan
input shape: torch.Size([1, 286])
loss: nan
input shape: torch.Size([1, 292])


311it [00:48,  6.68it/s]

loss: nan
input shape: torch.Size([1, 243])
loss: nan
input shape: torch.Size([1, 369])


313it [00:48,  6.74it/s]

loss: nan
input shape: torch.Size([1, 382])
loss: nan
input shape: torch.Size([1, 231])


315it [00:49,  6.73it/s]

loss: nan
input shape: torch.Size([1, 109])
loss: nan
input shape: torch.Size([1, 256])


317it [00:49,  6.72it/s]

loss: nan
input shape: torch.Size([1, 225])
loss: nan
input shape: torch.Size([1, 439])


319it [00:49,  6.65it/s]

loss: nan
input shape: torch.Size([1, 274])
loss: nan
input shape: torch.Size([1, 489])


321it [00:50,  6.61it/s]

loss: nan
input shape: torch.Size([1, 422])
loss: nan
input shape: torch.Size([1, 489])


323it [00:50,  6.47it/s]

loss: nan
input shape: torch.Size([1, 555])
loss: nan
input shape: torch.Size([1, 118])


325it [00:50,  6.50it/s]

loss: nan
input shape: torch.Size([1, 445])
loss: nan
input shape: torch.Size([1, 226])


327it [00:50,  6.56it/s]

loss: nan
input shape: torch.Size([1, 490])
loss: nan
input shape: torch.Size([1, 314])


329it [00:51,  6.79it/s]

loss: nan
input shape: torch.Size([1, 345])
loss: nan
input shape: torch.Size([1, 326])


331it [00:51,  6.67it/s]

loss: nan
input shape: torch.Size([1, 536])
loss: nan
input shape: torch.Size([1, 363])


333it [00:51,  6.75it/s]

loss: nan
input shape: torch.Size([1, 205])
loss: nan
input shape: torch.Size([1, 282])


335it [00:52,  6.75it/s]

loss: nan
input shape: torch.Size([1, 275])
loss: nan
input shape: torch.Size([1, 332])


337it [00:52,  6.81it/s]

loss: nan
input shape: torch.Size([1, 197])
loss: nan
input shape: torch.Size([1, 171])


339it [00:52,  6.67it/s]

loss: nan
input shape: torch.Size([1, 155])
loss: nan
input shape: torch.Size([1, 192])


341it [00:53,  6.70it/s]

loss: nan
input shape: torch.Size([1, 237])
loss: nan
input shape: torch.Size([1, 109])


343it [00:53,  6.67it/s]

loss: nan
input shape: torch.Size([1, 158])
loss: nan
input shape: torch.Size([1, 231])


345it [00:53,  6.75it/s]

loss: nan
input shape: torch.Size([1, 290])
loss: nan
input shape: torch.Size([1, 475])


347it [00:53,  6.69it/s]

loss: nan
input shape: torch.Size([1, 102])
loss: nan
input shape: torch.Size([1, 339])


349it [00:54,  6.66it/s]

loss: nan
input shape: torch.Size([1, 326])
loss: nan
input shape: torch.Size([1, 234])


351it [00:54,  6.72it/s]

loss: nan
input shape: torch.Size([1, 169])
loss: nan
input shape: torch.Size([1, 447])


353it [00:54,  6.18it/s]

loss: nan
input shape: torch.Size([1, 710])
loss: nan
input shape: torch.Size([1, 192])


355it [00:55,  6.03it/s]

loss: nan
input shape: torch.Size([1, 704])
loss: nan
input shape: torch.Size([1, 243])


357it [00:55,  6.21it/s]

loss: nan
input shape: torch.Size([1, 478])
loss: nan
input shape: torch.Size([1, 429])


359it [00:55,  6.35it/s]

loss: nan
input shape: torch.Size([1, 184])
loss: nan
input shape: torch.Size([1, 687])


361it [00:56,  6.24it/s]

loss: nan
input shape: torch.Size([1, 203])
loss: nan
input shape: torch.Size([1, 339])


363it [00:56,  6.45it/s]

loss: nan
input shape: torch.Size([1, 432])
loss: nan
input shape: torch.Size([1, 136])


365it [00:56,  6.59it/s]

loss: nan
input shape: torch.Size([1, 329])
loss: nan
input shape: torch.Size([1, 614])


367it [00:57,  6.54it/s]

loss: nan
input shape: torch.Size([1, 268])
loss: nan
input shape: torch.Size([1, 457])


369it [00:57,  6.66it/s]

loss: nan
input shape: torch.Size([1, 370])
loss: nan
input shape: torch.Size([1, 62])


371it [00:57,  6.63it/s]

loss: nan
input shape: torch.Size([1, 557])
loss: nan
input shape: torch.Size([1, 134])


373it [00:57,  6.94it/s]

loss: nan
input shape: torch.Size([1, 107])
loss: nan
input shape: torch.Size([1, 327])


375it [00:58,  7.02it/s]

loss: nan
input shape: torch.Size([1, 326])
loss: nan
input shape: torch.Size([1, 257])


377it [00:58,  7.04it/s]

loss: nan
input shape: torch.Size([1, 36])
loss: nan
input shape: torch.Size([1, 267])


379it [00:58,  6.90it/s]

loss: nan
input shape: torch.Size([1, 410])
loss: nan
input shape: torch.Size([1, 267])


381it [00:59,  7.00it/s]

loss: nan
input shape: torch.Size([1, 130])
loss: nan
input shape: torch.Size([1, 238])


383it [00:59,  6.76it/s]

loss: nan
input shape: torch.Size([1, 556])
loss: nan
input shape: torch.Size([1, 84])


385it [00:59,  6.35it/s]

loss: nan
input shape: torch.Size([1, 687])
loss: nan
input shape: torch.Size([1, 381])


387it [00:59,  6.71it/s]

loss: nan
input shape: torch.Size([1, 205])
loss: nan
input shape: torch.Size([1, 451])


389it [01:00,  6.48it/s]

loss: nan
input shape: torch.Size([1, 573])
loss: nan
input shape: torch.Size([1, 515])


391it [01:00,  6.41it/s]

loss: nan
input shape: torch.Size([1, 511])
loss: nan
input shape: torch.Size([1, 380])


393it [01:00,  6.49it/s]

loss: nan
input shape: torch.Size([1, 508])
loss: nan
input shape: torch.Size([1, 291])


395it [01:01,  6.53it/s]

loss: nan
input shape: torch.Size([1, 542])
loss: nan
input shape: torch.Size([1, 237])


397it [01:01,  6.87it/s]

loss: nan
input shape: torch.Size([1, 134])
loss: nan
input shape: torch.Size([1, 145])


399it [01:01,  7.03it/s]

loss: nan
input shape: torch.Size([1, 319])
loss: nan
input shape: torch.Size([1, 295])


401it [01:02,  6.72it/s]

loss: nan
input shape: torch.Size([1, 476])
loss: nan
input shape: torch.Size([1, 467])


403it [01:02,  6.81it/s]

loss: nan
input shape: torch.Size([1, 200])
loss: nan
input shape: torch.Size([1, 236])


405it [01:02,  6.93it/s]

loss: nan
input shape: torch.Size([1, 273])
loss: nan
input shape: torch.Size([1, 354])


407it [01:02,  7.05it/s]

loss: nan
input shape: torch.Size([1, 213])
loss: nan
input shape: torch.Size([1, 192])


408it [01:03,  7.14it/s]

loss: nan
input shape: torch.Size([1, 771])


410it [01:03,  6.39it/s]

loss: nan
input shape: torch.Size([1, 194])
loss: nan
input shape: torch.Size([1, 224])


412it [01:03,  6.37it/s]

loss: nan
input shape: torch.Size([1, 557])
loss: nan
input shape: torch.Size([1, 353])


414it [01:04,  6.44it/s]

loss: nan
input shape: torch.Size([1, 438])
loss: nan
input shape: torch.Size([1, 443])


416it [01:04,  6.59it/s]

loss: nan
input shape: torch.Size([1, 132])
loss: nan
input shape: torch.Size([1, 400])


418it [01:04,  6.75it/s]

loss: nan
input shape: torch.Size([1, 276])
loss: nan
input shape: torch.Size([1, 115])


420it [01:04,  6.94it/s]

loss: nan
input shape: torch.Size([1, 192])
loss: nan
input shape: torch.Size([1, 553])


422it [01:05,  6.88it/s]

loss: nan
input shape: torch.Size([1, 242])
loss: nan
input shape: torch.Size([1, 127])


424it [01:05,  6.88it/s]

loss: nan
input shape: torch.Size([1, 412])
loss: nan
input shape: torch.Size([1, 683])


426it [01:05,  6.34it/s]

loss: nan
input shape: torch.Size([1, 535])
loss: nan
input shape: torch.Size([1, 307])


428it [01:06,  6.47it/s]

loss: nan
input shape: torch.Size([1, 536])
loss: nan
input shape: torch.Size([1, 204])


430it [01:06,  6.80it/s]

loss: nan
input shape: torch.Size([1, 382])
loss: nan
input shape: torch.Size([1, 166])


432it [01:06,  6.87it/s]

loss: nan
input shape: torch.Size([1, 433])
loss: nan
input shape: torch.Size([1, 327])


434it [01:06,  6.72it/s]

loss: nan
input shape: torch.Size([1, 570])
loss: nan
input shape: torch.Size([1, 627])


436it [01:07,  6.64it/s]

loss: nan
input shape: torch.Size([1, 257])
loss: nan
input shape: torch.Size([1, 361])


438it [01:07,  6.81it/s]

loss: nan
input shape: torch.Size([1, 416])
loss: nan
input shape: torch.Size([1, 224])


440it [01:07,  6.99it/s]

loss: nan
input shape: torch.Size([1, 168])
loss: nan
input shape: torch.Size([1, 147])


442it [01:08,  7.03it/s]

loss: nan
input shape: torch.Size([1, 258])
loss: nan
input shape: torch.Size([1, 504])


444it [01:08,  6.84it/s]

loss: nan
input shape: torch.Size([1, 303])
loss: nan
input shape: torch.Size([1, 657])


446it [01:08,  6.56it/s]

loss: nan
input shape: torch.Size([1, 332])
loss: nan
input shape: torch.Size([1, 434])


448it [01:09,  6.50it/s]

loss: nan
input shape: torch.Size([1, 542])
loss: nan
input shape: torch.Size([1, 543])


450it [01:09,  6.58it/s]

loss: nan
input shape: torch.Size([1, 346])
loss: nan
input shape: torch.Size([1, 333])


452it [01:09,  6.72it/s]

loss: nan
input shape: torch.Size([1, 358])
loss: nan
input shape: torch.Size([1, 239])


454it [01:09,  7.01it/s]

loss: nan
input shape: torch.Size([1, 325])
loss: nan
input shape: torch.Size([1, 545])


456it [01:10,  6.53it/s]

loss: nan
input shape: torch.Size([1, 594])
loss: nan
input shape: torch.Size([1, 233])


458it [01:10,  6.85it/s]

loss: nan
input shape: torch.Size([1, 218])
loss: nan
input shape: torch.Size([1, 192])


460it [01:10,  6.96it/s]

loss: nan
input shape: torch.Size([1, 189])
loss: nan
input shape: torch.Size([1, 218])


462it [01:11,  7.09it/s]

loss: nan
input shape: torch.Size([1, 232])
loss: nan
input shape: torch.Size([1, 355])


464it [01:11,  6.81it/s]

loss: nan
input shape: torch.Size([1, 515])
loss: nan
input shape: torch.Size([1, 333])


466it [01:11,  7.01it/s]

loss: nan
input shape: torch.Size([1, 249])
loss: nan
input shape: torch.Size([1, 328])


468it [01:11,  7.15it/s]

loss: nan
input shape: torch.Size([1, 184])
loss: nan
input shape: torch.Size([1, 200])


470it [01:12,  7.15it/s]

loss: nan
input shape: torch.Size([1, 223])
loss: nan
input shape: torch.Size([1, 466])


472it [01:12,  7.01it/s]

loss: nan
input shape: torch.Size([1, 309])
loss: nan
input shape: torch.Size([1, 193])


474it [01:12,  7.08it/s]

loss: nan
input shape: torch.Size([1, 269])
loss: nan
input shape: torch.Size([1, 498])


476it [01:13,  6.84it/s]

loss: nan
input shape: torch.Size([1, 322])
loss: nan
input shape: torch.Size([1, 279])


478it [01:13,  6.94it/s]

loss: nan
input shape: torch.Size([1, 261])
loss: nan
input shape: torch.Size([1, 215])


480it [01:13,  6.89it/s]

loss: nan
input shape: torch.Size([1, 241])
loss: nan
input shape: torch.Size([1, 173])


482it [01:14,  6.84it/s]

loss: nan
input shape: torch.Size([1, 384])
loss: nan
input shape: torch.Size([1, 127])


484it [01:14,  6.71it/s]

loss: nan
input shape: torch.Size([1, 437])
loss: nan
input shape: torch.Size([1, 253])


486it [01:14,  6.45it/s]

loss: nan
input shape: torch.Size([1, 593])
loss: nan
input shape: torch.Size([1, 309])


488it [01:14,  6.41it/s]

loss: nan
input shape: torch.Size([1, 466])
loss: nan
input shape: torch.Size([1, 242])


490it [01:15,  6.63it/s]

loss: nan
input shape: torch.Size([1, 148])
loss: nan
input shape: torch.Size([1, 350])


492it [01:15,  6.61it/s]

loss: nan
input shape: torch.Size([1, 536])
loss: nan
input shape: torch.Size([1, 474])


494it [01:15,  6.64it/s]

loss: nan
input shape: torch.Size([1, 371])
loss: nan
input shape: torch.Size([1, 618])


496it [01:16,  6.53it/s]

loss: nan
input shape: torch.Size([1, 235])
loss: nan
input shape: torch.Size([1, 192])


498it [01:16,  6.67it/s]

loss: nan
input shape: torch.Size([1, 401])
loss: nan
input shape: torch.Size([1, 344])


500it [01:16,  6.72it/s]

loss: nan
input shape: torch.Size([1, 258])
loss: nan
input shape: torch.Size([1, 559])


502it [01:17,  6.68it/s]

loss: nan
input shape: torch.Size([1, 350])
loss: nan
input shape: torch.Size([1, 390])


504it [01:17,  6.69it/s]

loss: nan
input shape: torch.Size([1, 136])
loss: nan
input shape: torch.Size([1, 596])


506it [01:17,  6.52it/s]

loss: nan
input shape: torch.Size([1, 469])
loss: nan
input shape: torch.Size([1, 283])


508it [01:17,  6.73it/s]

loss: nan
input shape: torch.Size([1, 180])
loss: nan
input shape: torch.Size([1, 485])


510it [01:18,  6.67it/s]

loss: nan
input shape: torch.Size([1, 412])
loss: nan
input shape: torch.Size([1, 261])


512it [01:18,  6.58it/s]

loss: nan
input shape: torch.Size([1, 391])
loss: nan
input shape: torch.Size([1, 222])


514it [01:18,  6.51it/s]

loss: nan
input shape: torch.Size([1, 590])
loss: nan
input shape: torch.Size([1, 249])


516it [01:19,  6.60it/s]

loss: nan
input shape: torch.Size([1, 481])
loss: nan
input shape: torch.Size([1, 134])


518it [01:19,  6.60it/s]

loss: nan
input shape: torch.Size([1, 452])
loss: nan
input shape: torch.Size([1, 485])


520it [01:19,  6.43it/s]

loss: nan
input shape: torch.Size([1, 488])
loss: nan
input shape: torch.Size([1, 374])


522it [01:20,  6.43it/s]

loss: nan
input shape: torch.Size([1, 574])
loss: nan
input shape: torch.Size([1, 595])


524it [01:20,  6.20it/s]

loss: nan
input shape: torch.Size([1, 612])
loss: nan
input shape: torch.Size([1, 259])


526it [01:20,  6.47it/s]

loss: nan
input shape: torch.Size([1, 245])
loss: nan
input shape: torch.Size([1, 368])


528it [01:21,  6.58it/s]

loss: nan
input shape: torch.Size([1, 209])
loss: nan
input shape: torch.Size([1, 71])


530it [01:21,  6.62it/s]

loss: nan
input shape: torch.Size([1, 377])
loss: nan
input shape: torch.Size([1, 208])


532it [01:21,  6.62it/s]

loss: nan
input shape: torch.Size([1, 413])
loss: nan
input shape: torch.Size([1, 152])


534it [01:21,  6.69it/s]

loss: nan
input shape: torch.Size([1, 189])
loss: nan
input shape: torch.Size([1, 423])


536it [01:22,  6.73it/s]

loss: nan
input shape: torch.Size([1, 325])
loss: nan
input shape: torch.Size([1, 199])


538it [01:22,  6.78it/s]

loss: nan
input shape: torch.Size([1, 243])
loss: nan
input shape: torch.Size([1, 614])


540it [01:22,  6.38it/s]

loss: nan
input shape: torch.Size([1, 543])
loss: nan
input shape: torch.Size([1, 470])


542it [01:23,  6.25it/s]

loss: nan
input shape: torch.Size([1, 489])
loss: nan
input shape: torch.Size([1, 686])


544it [01:23,  6.02it/s]

loss: nan
input shape: torch.Size([1, 545])
loss: nan
input shape: torch.Size([1, 493])


546it [01:23,  6.19it/s]

loss: nan
input shape: torch.Size([1, 419])
loss: nan
input shape: torch.Size([1, 269])


548it [01:24,  6.48it/s]

loss: nan
input shape: torch.Size([1, 132])
loss: nan
input shape: torch.Size([1, 303])


550it [01:24,  6.59it/s]

loss: nan
input shape: torch.Size([1, 448])
loss: nan
input shape: torch.Size([1, 249])


552it [01:24,  6.66it/s]

loss: nan
input shape: torch.Size([1, 60])
loss: nan
input shape: torch.Size([1, 255])


554it [01:25,  6.69it/s]

loss: nan
input shape: torch.Size([1, 220])
loss: nan
input shape: torch.Size([1, 191])


556it [01:25,  6.68it/s]

loss: nan
input shape: torch.Size([1, 424])
loss: nan
input shape: torch.Size([1, 285])


558it [01:25,  6.66it/s]

loss: nan
input shape: torch.Size([1, 410])
loss: nan
input shape: torch.Size([1, 211])


560it [01:25,  6.66it/s]

loss: nan
input shape: torch.Size([1, 273])
loss: nan
input shape: torch.Size([1, 212])


562it [01:26,  6.63it/s]

loss: nan
input shape: torch.Size([1, 457])
loss: nan
input shape: torch.Size([1, 490])


564it [01:26,  6.51it/s]

loss: nan
input shape: torch.Size([1, 295])
loss: nan
input shape: torch.Size([1, 566])


566it [01:26,  6.39it/s]

loss: nan
input shape: torch.Size([1, 200])
loss: nan
input shape: torch.Size([1, 442])


568it [01:27,  6.41it/s]

loss: nan
input shape: torch.Size([1, 184])
loss: nan
input shape: torch.Size([1, 364])


570it [01:27,  6.46it/s]

loss: nan
input shape: torch.Size([1, 332])
loss: nan
input shape: torch.Size([1, 486])


572it [01:27,  6.44it/s]

loss: nan
input shape: torch.Size([1, 158])
loss: nan
input shape: torch.Size([1, 300])


574it [01:28,  6.52it/s]

loss: nan
input shape: torch.Size([1, 256])
loss: nan
input shape: torch.Size([1, 130])


576it [01:28,  6.51it/s]

loss: nan
input shape: torch.Size([1, 410])
loss: nan
input shape: torch.Size([1, 348])


578it [01:28,  6.36it/s]

loss: nan
input shape: torch.Size([1, 588])
loss: nan
input shape: torch.Size([1, 859])


580it [01:29,  5.99it/s]

loss: nan
input shape: torch.Size([1, 166])
loss: nan
input shape: torch.Size([1, 566])


582it [01:29,  6.20it/s]

loss: nan
input shape: torch.Size([1, 251])
loss: nan
input shape: torch.Size([1, 353])


584it [01:29,  6.29it/s]

loss: nan
input shape: torch.Size([1, 337])
loss: nan
input shape: torch.Size([1, 304])


586it [01:30,  6.20it/s]

loss: nan
input shape: torch.Size([1, 516])
loss: nan
input shape: torch.Size([1, 121])


588it [01:30,  6.28it/s]

loss: nan
input shape: torch.Size([1, 416])
loss: nan
input shape: torch.Size([1, 313])


590it [01:30,  6.18it/s]

loss: nan
input shape: torch.Size([1, 615])
loss: nan
input shape: torch.Size([1, 229])


592it [01:30,  6.20it/s]

loss: nan
input shape: torch.Size([1, 230])
loss: nan
input shape: torch.Size([1, 472])


592it [01:31,  6.50it/s]


KeyboardInterrupt: 

In [None]:
del model 