In [1]:
import matplotlib.pyplot as plt
from tqdm import trange
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np


# Assuming the npy files are named 'delta_0.npy', 'delta_1.npy', etc.

start_n = 0
end_n = 799

# Specify the index of the item to plot
item_index = 2500

# Initialize a list to store the item values
item_values = []
vector_values = []

# Load the item values from the npy files
for i in trange(start_n, end_n):
    vector = np.load(f'/user/as6154/dissert/L12_half_data/delta_{i}.npy')

    item_values.append(vector[item_index])
    vector_values.append(vector)

vector_values = np.array(vector_values)
# # Plot the item values
# plt.figure(figsize=(10, 6))
# plt.plot(item_values, label=f'Item {item_index}')

# plt.xlabel('File Index')
# plt.ylabel('Item Value')
# plt.title(f'Item {item_index} from NPY Files')
# plt.legend()
# plt.show()

100%|██████████| 799/799 [00:00<00:00, 1685.71it/s]


In [2]:
len(vector)

4096

In [3]:
import pickle
number_of_bins = 100000
dir_path = "/user/as6154/dissert/szhalf_L12_delta_-2_to_0_interval_400_secondDelta_1_to_1_interval_1_data"

# Define the path to save the pickle file
pickle_file_path = f'{dir_path}/{number_of_bins}_binned_values.pkl'
with open(pickle_file_path, 'rb') as f:
    binned_values_loaded = pickle.load(f)

print(f"Binned values loaded from pickle file (first 10): {binned_values_loaded[:10]}")
print(f"Binned values (last 10): {binned_values_loaded[-10:]}")  # Display first 10 binned values for reference

Binned values loaded from pickle file (first 10): [[99999 11065 11065 ... 11065 11065 11065]
 [99999 11065 11065 ... 11065 11065 11065]
 [99999 11065 11065 ... 11065 11065 11065]
 ...
 [99999 11065 11065 ... 11065 11065 11065]
 [99999 11065 11065 ... 11065 11065 11065]
 [99999 11065 11065 ... 11065 11065 11065]]
Binned values (last 10): [[11065 11065 11065 ... 11065 11065 11065]
 [11065 11065 11065 ... 11065 11065 11065]
 [11065 11065 11065 ... 11065 11065 11065]
 ...
 [11065 11065 11065 ... 11065 11065 11065]
 [11065 11065 11065 ... 11065 11065 11065]
 [11065 11065 11065 ... 11065 11065 11065]]


In [4]:

from sklearn.model_selection import train_test_split

sequence_length = 10
batch_size = 4

class HighDimensionalDataset(Dataset):
    def __init__(self, data, n_items = 1000, random = False):
        self.data = data
        self.random = random
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.random:
            start_index = np.random.randint(0, len(self.data) - sequence_length)
        else:
            start_index = min(idx, len(self.data) - sequence_length)
        sequence = self.data[start_index : start_index + sequence_length]
        return torch.tensor(sequence, dtype=torch.float32)



# Split the data into training and testing sets
train_data = binned_values_loaded
test_data = binned_values_loaded

# TODO: But this test_dataset would also work if it's overfitting?
train_dataset = HighDimensionalDataset(train_data, n_items = 1000, random = True)
test_dataset = HighDimensionalDataset(test_data, n_items = 1000, random = False)


# Create dataloaders for training and testing
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [5]:
train_dataset[100].shape

torch.Size([10, 4096])

Setting:

1. We have a sequence 800 data, that is across delta changing, each item is vector dim 4096.
2. I want to predict next data, conditioning on previous data. 
3. So my input and output is dim 4096. 

Idea is to 


-----

Normal language models take in batch of text, text is in high dimension, but it's a number alone.
so it's BxN.

Our's is BxNxD.

So i need to embed the D to 1D?


Can we embed continous instead of quantize? 
- but we need to tokenize first right?
- unless i train my own tokenizer with bpe?
- so we shouldnt?

In [6]:
train_dataset[0].shape[1]

4096

In [7]:
# Garg inspired model
from transformers import GPT2Config, GPT2Model
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(self, n_dims, n_positions, n_embd=4096, n_layer=12, n_head=4):
        super(Transformer, self).__init__()
        configuration = GPT2Config(
            n_positions=2 * n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
            use_cache=False,
        )
        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_dims
        
        self._read_in = nn.Linear(n_dims, n_embd)
        self._backbone = GPT2Model(configuration)
        self._read_out = nn.Linear(n_embd, n_dims)

    def forward(self, x):
        embeds = self._read_in(x)
        output = self._backbone(inputs_embeds = embeds).last_hidden_state
        prediction = self._read_out(output)
        return prediction
    
n_dims = train_dataset[0].shape[1]
n_positions = 1024

model = Transformer(n_dims, n_positions)

input_data = torch.randn(batch_size, sequence_length, n_dims)

output = model(input_data)
print(output.shape)

[2024-06-26 22:06:00,446] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
torch.Size([4, 10, 4096])


In [8]:
next(iter(train_dataloader)).shape

torch.Size([4, 10, 4096])

In [9]:
next(iter(test_dataloader)).shape

torch.Size([1, 10, 4096])

In [10]:
import wandb
from transformers import AdamW, get_linear_schedule_with_warmup
import pandas as pd


wandb.init(project="GPT_QFT")
loss_fn = nn.MSELoss()
device = "cuda:1"
num_epochs = 100
eval_steps = 100

optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)

# optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=len(train_dataloader) * num_epochs)



def evaluate_model(model, test_dataloader, device, loss_fn):
    loss_df = pd.DataFrame()
    model.eval()
    with torch.no_grad():
        loss_array = []
        for i, (test_input_seq) in enumerate(test_dataloader):
            test_target_seq = test_input_seq[:, 1:, :]
            test_input_seq = test_input_seq.to(device)
            test_target_seq = test_target_seq.to(device)

            test_output = model(test_input_seq)
            test_output = test_output[:, 1:, :]

            test_loss = loss_fn(test_output, test_target_seq)
            loss_array.append((i, test_loss.item()))
        loss_df = pd.DataFrame(loss_array, columns=['Index', 'Loss'])
        loss_df.drop(columns=['Index'], inplace=True)
        loss_df = loss_df.transpose()
        loss_df.to_csv('loss_data_n2_100000.csv', mode='a', header=False, index=False)

        loss_array.sort(key=lambda x: x[1], reverse=True)
        print(f'Highest Loss at index: {loss_array[0][0]}, Loss: {loss_array[0][1]}')
        print(f'Lowest Loss at index: {loss_array[-1][0]}, Loss: {loss_array[-1][1]}')
        wandb.log({"Highest Loss Index": loss_array[0][0], "Lowest Loss Index": loss_array[-1][0]})
    model.train()

for epoch in trange(num_epochs):
    for i, (input_seq) in enumerate(train_dataloader):

        target_seq = input_seq[:, 1:, :]

        model.to(device)
        input_seq = input_seq.to(device)
        target_seq = target_seq.to(device)
        output = model(input_seq)
        output = output[:, 1:, :]

        loss = loss_fn(output, target_seq)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if (i+1) % eval_steps == 0:
            evaluate_model(model, test_dataloader, device, loss_fn)
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item()}')
            wandb.log({"Loss": loss.item()})

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mandrewsiah[0m ([33mhong-exploration[0m). Use [1m`wandb login --relogin`[0m to force relogin


  1%|          | 1/100 [00:28<46:14, 28.02s/it]

Highest Loss at index: 0, Loss: 124377784.0
Lowest Loss at index: 266, Loss: 123781632.0
Epoch [1/100], Step [100/100], Loss: 124013880.0


  2%|▏         | 2/100 [00:53<43:18, 26.51s/it]

Highest Loss at index: 0, Loss: 123385712.0
Lowest Loss at index: 266, Loss: 122791840.0
Epoch [2/100], Step [100/100], Loss: 123163096.0


  3%|▎         | 3/100 [01:19<42:23, 26.22s/it]

Highest Loss at index: 0, Loss: 122344224.0
Lowest Loss at index: 266, Loss: 121752664.0
Epoch [3/100], Step [100/100], Loss: 122116552.0


  4%|▍         | 4/100 [01:47<43:26, 27.15s/it]

Highest Loss at index: 0, Loss: 121248056.0
Lowest Loss at index: 266, Loss: 120659032.0
Epoch [4/100], Step [100/100], Loss: 121259280.0


  5%|▌         | 5/100 [02:15<43:28, 27.46s/it]

Highest Loss at index: 0, Loss: 120100712.0
Lowest Loss at index: 266, Loss: 119514528.0
Epoch [5/100], Step [100/100], Loss: 120112424.0


  6%|▌         | 6/100 [02:44<43:45, 27.93s/it]

Highest Loss at index: 0, Loss: 118906824.0
Lowest Loss at index: 266, Loss: 118323560.0
Epoch [6/100], Step [100/100], Loss: 118457120.0


  7%|▋         | 7/100 [03:12<43:25, 28.01s/it]

Highest Loss at index: 0, Loss: 117671648.0
Lowest Loss at index: 266, Loss: 117091224.0
Epoch [7/100], Step [100/100], Loss: 117684176.0


  8%|▊         | 8/100 [03:40<42:57, 28.02s/it]

Highest Loss at index: 0, Loss: 116399976.0
Lowest Loss at index: 266, Loss: 115822320.0
Epoch [8/100], Step [100/100], Loss: 116168848.0


  9%|▉         | 9/100 [04:08<42:17, 27.89s/it]

Highest Loss at index: 0, Loss: 115096592.0
Lowest Loss at index: 266, Loss: 114522056.0
Epoch [9/100], Step [100/100], Loss: 114996384.0


 10%|█         | 10/100 [04:35<41:36, 27.74s/it]

Highest Loss at index: 0, Loss: 113765776.0
Lowest Loss at index: 266, Loss: 113194880.0
Epoch [10/100], Step [100/100], Loss: 113439984.0


 11%|█         | 11/100 [05:03<40:52, 27.55s/it]

Highest Loss at index: 0, Loss: 112411592.0
Lowest Loss at index: 266, Loss: 111844184.0
Epoch [11/100], Step [100/100], Loss: 112199976.0


 12%|█▏        | 12/100 [05:30<40:11, 27.40s/it]

Highest Loss at index: 0, Loss: 111038208.0
Lowest Loss at index: 266, Loss: 110473760.0
Epoch [12/100], Step [100/100], Loss: 110830304.0


 13%|█▎        | 13/100 [05:57<39:39, 27.35s/it]

Highest Loss at index: 0, Loss: 109648456.0
Lowest Loss at index: 266, Loss: 109087832.0
Epoch [13/100], Step [100/100], Loss: 109330696.0


 14%|█▍        | 14/100 [06:24<39:15, 27.39s/it]

Highest Loss at index: 0, Loss: 108245992.0
Lowest Loss at index: 266, Loss: 107688776.0
Epoch [14/100], Step [100/100], Loss: 108038472.0


 15%|█▌        | 15/100 [06:51<38:36, 27.25s/it]

Highest Loss at index: 0, Loss: 106833512.0
Lowest Loss at index: 266, Loss: 106279528.0
Epoch [15/100], Step [100/100], Loss: 106620800.0


 16%|█▌        | 16/100 [07:19<38:27, 27.47s/it]

Highest Loss at index: 0, Loss: 105413704.0
Lowest Loss at index: 266, Loss: 104863264.0
Epoch [16/100], Step [100/100], Loss: 105313064.0


 17%|█▋        | 17/100 [07:47<38:03, 27.51s/it]

Highest Loss at index: 0, Loss: 103989280.0
Lowest Loss at index: 266, Loss: 103442704.0
Epoch [17/100], Step [100/100], Loss: 103687384.0


 18%|█▊        | 18/100 [08:14<37:25, 27.39s/it]

Highest Loss at index: 0, Loss: 102562600.0
Lowest Loss at index: 266, Loss: 102019640.0
Epoch [18/100], Step [100/100], Loss: 102122216.0


 19%|█▉        | 19/100 [08:42<37:02, 27.43s/it]

Highest Loss at index: 0, Loss: 101135712.0
Lowest Loss at index: 266, Loss: 100596648.0
Epoch [19/100], Step [100/100], Loss: 100927816.0


 20%|██        | 20/100 [09:09<36:32, 27.40s/it]

Highest Loss at index: 0, Loss: 99710576.0
Lowest Loss at index: 266, Loss: 99175336.0
Epoch [20/100], Step [100/100], Loss: 99507216.0


 21%|██        | 21/100 [09:36<35:57, 27.31s/it]

Highest Loss at index: 0, Loss: 98289112.0
Lowest Loss at index: 266, Loss: 97757568.0
Epoch [21/100], Step [100/100], Loss: 98198728.0


 22%|██▏       | 22/100 [10:03<35:28, 27.29s/it]

Highest Loss at index: 0, Loss: 96873128.0
Lowest Loss at index: 266, Loss: 96345728.0
Epoch [22/100], Step [100/100], Loss: 96887272.0


 23%|██▎       | 23/100 [10:30<34:59, 27.27s/it]

Highest Loss at index: 0, Loss: 95464432.0
Lowest Loss at index: 266, Loss: 94940536.0
Epoch [23/100], Step [100/100], Loss: 95163120.0


 24%|██▍       | 24/100 [10:58<34:31, 27.26s/it]

Highest Loss at index: 0, Loss: 94064040.0
Lowest Loss at index: 266, Loss: 93543960.0
Epoch [24/100], Step [100/100], Loss: 93857336.0


 25%|██▌       | 25/100 [11:25<33:57, 27.17s/it]

Highest Loss at index: 0, Loss: 92673760.0
Lowest Loss at index: 266, Loss: 92157296.0
Epoch [25/100], Step [100/100], Loss: 92362152.0


 26%|██▌       | 26/100 [11:51<33:13, 26.94s/it]

Highest Loss at index: 0, Loss: 91294592.0
Lowest Loss at index: 266, Loss: 90781800.0
Epoch [26/100], Step [100/100], Loss: 90994048.0


 27%|██▋       | 27/100 [12:17<32:25, 26.65s/it]

Highest Loss at index: 0, Loss: 89927296.0
Lowest Loss at index: 266, Loss: 89418208.0
Epoch [27/100], Step [100/100], Loss: 89940912.0


 28%|██▊       | 28/100 [12:43<31:41, 26.41s/it]

Highest Loss at index: 0, Loss: 88573408.0
Lowest Loss at index: 266, Loss: 88068008.0
Epoch [28/100], Step [100/100], Loss: 88177720.0


 29%|██▉       | 29/100 [13:09<31:02, 26.23s/it]

Highest Loss at index: 0, Loss: 87233512.0
Lowest Loss at index: 266, Loss: 86731944.0
Epoch [29/100], Step [100/100], Loss: 87141608.0


 30%|███       | 30/100 [13:34<30:24, 26.06s/it]

Highest Loss at index: 0, Loss: 85908520.0
Lowest Loss at index: 266, Loss: 85410920.0
Epoch [30/100], Step [100/100], Loss: 85826064.0


 31%|███       | 31/100 [14:00<29:47, 25.90s/it]

Highest Loss at index: 0, Loss: 84599600.0
Lowest Loss at index: 266, Loss: 84105688.0
Epoch [31/100], Step [100/100], Loss: 84418368.0


 32%|███▏      | 32/100 [14:26<29:16, 25.83s/it]

Highest Loss at index: 0, Loss: 83307240.0
Lowest Loss at index: 266, Loss: 82816720.0
Epoch [32/100], Step [100/100], Loss: 82944704.0


 33%|███▎      | 33/100 [14:51<28:50, 25.83s/it]

Highest Loss at index: 0, Loss: 82031672.0
Lowest Loss at index: 266, Loss: 81544712.0
Epoch [33/100], Step [100/100], Loss: 81733928.0


 34%|███▍      | 34/100 [15:17<28:22, 25.79s/it]

Highest Loss at index: 0, Loss: 80773760.0
Lowest Loss at index: 266, Loss: 80290496.0
Epoch [34/100], Step [100/100], Loss: 80695008.0


 35%|███▌      | 35/100 [15:43<27:51, 25.71s/it]

Highest Loss at index: 0, Loss: 79533856.0
Lowest Loss at index: 266, Loss: 79054224.0
Epoch [35/100], Step [100/100], Loss: 79448000.0


 36%|███▌      | 36/100 [16:08<27:23, 25.68s/it]

Highest Loss at index: 0, Loss: 78312688.0
Lowest Loss at index: 266, Loss: 77836832.0
Epoch [36/100], Step [100/100], Loss: 78233112.0


 37%|███▋      | 37/100 [16:34<26:58, 25.69s/it]

Highest Loss at index: 0, Loss: 77110632.0
Lowest Loss at index: 266, Loss: 76638136.0
Epoch [37/100], Step [100/100], Loss: 77024424.0


 38%|███▊      | 38/100 [16:59<26:27, 25.60s/it]

Highest Loss at index: 0, Loss: 75927864.0
Lowest Loss at index: 266, Loss: 75458920.0
Epoch [38/100], Step [100/100], Loss: 75652760.0


 39%|███▉      | 39/100 [17:25<26:03, 25.63s/it]

Highest Loss at index: 0, Loss: 74764600.0
Lowest Loss at index: 266, Loss: 74298824.0
Epoch [39/100], Step [100/100], Loss: 74497240.0


 40%|████      | 40/100 [17:51<25:37, 25.62s/it]

Highest Loss at index: 0, Loss: 73621032.0
Lowest Loss at index: 266, Loss: 73158960.0
Epoch [40/100], Step [100/100], Loss: 73538792.0


 41%|████      | 41/100 [18:16<25:07, 25.56s/it]

Highest Loss at index: 0, Loss: 72497608.0
Lowest Loss at index: 266, Loss: 72038944.0
Epoch [41/100], Step [100/100], Loss: 72421712.0


 42%|████▏     | 42/100 [18:41<24:38, 25.49s/it]

Highest Loss at index: 0, Loss: 71394600.0
Lowest Loss at index: 266, Loss: 70939592.0
Epoch [42/100], Step [100/100], Loss: 71317480.0


 43%|████▎     | 43/100 [19:07<24:15, 25.53s/it]

Highest Loss at index: 0, Loss: 70312360.0
Lowest Loss at index: 266, Loss: 69860768.0
Epoch [43/100], Step [100/100], Loss: 70323088.0


 44%|████▍     | 44/100 [19:33<23:50, 25.54s/it]

Highest Loss at index: 0, Loss: 69250536.0
Lowest Loss at index: 266, Loss: 68802432.0
Epoch [44/100], Step [100/100], Loss: 69083464.0


 45%|████▌     | 45/100 [19:58<23:26, 25.57s/it]

Highest Loss at index: 0, Loss: 68209520.0
Lowest Loss at index: 266, Loss: 67764256.0
Epoch [45/100], Step [100/100], Loss: 67941352.0


 46%|████▌     | 46/100 [20:24<23:00, 25.57s/it]

Highest Loss at index: 0, Loss: 67189280.0
Lowest Loss at index: 266, Loss: 66747536.0
Epoch [46/100], Step [100/100], Loss: 67022592.0


 47%|████▋     | 47/100 [20:49<22:33, 25.54s/it]

Highest Loss at index: 0, Loss: 66189840.0
Lowest Loss at index: 266, Loss: 65751076.0
Epoch [47/100], Step [100/100], Loss: 66199736.0


 48%|████▊     | 48/100 [21:15<22:07, 25.52s/it]

Highest Loss at index: 0, Loss: 65211884.0
Lowest Loss at index: 266, Loss: 64776244.0
Epoch [48/100], Step [100/100], Loss: 65133576.0


 49%|████▉     | 49/100 [21:40<21:41, 25.52s/it]

Highest Loss at index: 0, Loss: 64254804.0
Lowest Loss at index: 266, Loss: 63822324.0
Epoch [49/100], Step [100/100], Loss: 64175160.0


 50%|█████     | 50/100 [22:06<21:16, 25.53s/it]

Highest Loss at index: 0, Loss: 63318480.0
Lowest Loss at index: 266, Loss: 62889168.0
Epoch [50/100], Step [100/100], Loss: 62997128.0


 51%|█████     | 51/100 [22:31<20:51, 25.54s/it]

Highest Loss at index: 0, Loss: 62402440.0
Lowest Loss at index: 266, Loss: 61976200.0
Epoch [51/100], Step [100/100], Loss: 62322092.0


 52%|█████▏    | 52/100 [22:57<20:27, 25.58s/it]

Highest Loss at index: 0, Loss: 61507528.0
Lowest Loss at index: 266, Loss: 61084352.0
Epoch [52/100], Step [100/100], Loss: 61434588.0


 53%|█████▎    | 53/100 [23:23<20:02, 25.59s/it]

Highest Loss at index: 0, Loss: 60633360.0
Lowest Loss at index: 266, Loss: 60213432.0
Epoch [53/100], Step [100/100], Loss: 60471608.0


 54%|█████▍    | 54/100 [23:48<19:36, 25.57s/it]

Highest Loss at index: 0, Loss: 59780240.0
Lowest Loss at index: 266, Loss: 59363112.0
Epoch [54/100], Step [100/100], Loss: 59542784.0


 55%|█████▌    | 55/100 [24:14<19:10, 25.57s/it]

Highest Loss at index: 0, Loss: 58947964.0
Lowest Loss at index: 266, Loss: 58534124.0
Epoch [55/100], Step [100/100], Loss: 58734948.0


 56%|█████▌    | 56/100 [24:39<18:44, 25.55s/it]

Highest Loss at index: 0, Loss: 58136256.0
Lowest Loss at index: 266, Loss: 57725364.0
Epoch [56/100], Step [100/100], Loss: 58058340.0


 57%|█████▋    | 57/100 [25:05<18:19, 25.56s/it]

Highest Loss at index: 0, Loss: 57344788.0
Lowest Loss at index: 266, Loss: 56936564.0
Epoch [57/100], Step [100/100], Loss: 57188848.0


 58%|█████▊    | 58/100 [25:31<17:57, 25.65s/it]

Highest Loss at index: 0, Loss: 56573728.0
Lowest Loss at index: 266, Loss: 56168144.0
Epoch [58/100], Step [100/100], Loss: 56489536.0


 59%|█████▉    | 59/100 [25:56<17:31, 25.64s/it]

Highest Loss at index: 0, Loss: 55823228.0
Lowest Loss at index: 266, Loss: 55420196.0
Epoch [59/100], Step [100/100], Loss: 55659628.0


 60%|██████    | 60/100 [26:22<17:06, 25.66s/it]

Highest Loss at index: 0, Loss: 55093072.0
Lowest Loss at index: 266, Loss: 54692416.0
Epoch [60/100], Step [100/100], Loss: 54962352.0


 61%|██████    | 61/100 [26:48<16:39, 25.64s/it]

Highest Loss at index: 0, Loss: 54383260.0
Lowest Loss at index: 266, Loss: 53984940.0
Epoch [61/100], Step [100/100], Loss: 54072748.0


 62%|██████▏   | 62/100 [27:13<16:17, 25.73s/it]

Highest Loss at index: 0, Loss: 53693368.0
Lowest Loss at index: 266, Loss: 53297284.0
Epoch [62/100], Step [100/100], Loss: 53542628.0


 63%|██████▎   | 63/100 [27:39<15:50, 25.70s/it]

Highest Loss at index: 0, Loss: 53022996.0
Lowest Loss at index: 266, Loss: 52629360.0
Epoch [63/100], Step [100/100], Loss: 52870144.0


 64%|██████▍   | 64/100 [28:05<15:21, 25.61s/it]

Highest Loss at index: 0, Loss: 52372564.0
Lowest Loss at index: 266, Loss: 51981560.0
Epoch [64/100], Step [100/100], Loss: 52305780.0


 65%|██████▌   | 65/100 [28:30<14:55, 25.59s/it]

Highest Loss at index: 0, Loss: 51742036.0
Lowest Loss at index: 266, Loss: 51353396.0
Epoch [65/100], Step [100/100], Loss: 51593140.0


 66%|██████▌   | 66/100 [28:56<14:30, 25.59s/it]

Highest Loss at index: 0, Loss: 51131116.0
Lowest Loss at index: 266, Loss: 50744668.0
Epoch [66/100], Step [100/100], Loss: 50987968.0


 67%|██████▋   | 67/100 [29:23<14:23, 26.17s/it]

Highest Loss at index: 0, Loss: 50539812.0
Lowest Loss at index: 266, Loss: 50155508.0
Epoch [67/100], Step [100/100], Loss: 50473152.0


 68%|██████▊   | 68/100 [29:49<13:53, 26.06s/it]

Highest Loss at index: 0, Loss: 49967496.0
Lowest Loss at index: 266, Loss: 49585268.0
Epoch [68/100], Step [100/100], Loss: 49824732.0


 69%|██████▉   | 69/100 [30:16<13:36, 26.34s/it]

Highest Loss at index: 0, Loss: 49414012.0
Lowest Loss at index: 266, Loss: 49033736.0
Epoch [69/100], Step [100/100], Loss: 49345016.0


 70%|███████   | 70/100 [30:43<13:16, 26.54s/it]

Highest Loss at index: 0, Loss: 48879592.0
Lowest Loss at index: 266, Loss: 48501424.0
Epoch [70/100], Step [100/100], Loss: 48742580.0


 71%|███████   | 71/100 [31:10<12:54, 26.71s/it]

Highest Loss at index: 0, Loss: 48364360.0
Lowest Loss at index: 266, Loss: 47988216.0
Epoch [71/100], Step [100/100], Loss: 48159720.0


 72%|███████▏  | 72/100 [31:37<12:32, 26.86s/it]

Highest Loss at index: 0, Loss: 47867816.0
Lowest Loss at index: 266, Loss: 47493732.0
Epoch [72/100], Step [100/100], Loss: 47797152.0


 73%|███████▎  | 73/100 [32:05<12:08, 26.99s/it]

Highest Loss at index: 0, Loss: 47390104.0
Lowest Loss at index: 266, Loss: 47017828.0
Epoch [73/100], Step [100/100], Loss: 47173768.0


 74%|███████▍  | 74/100 [32:32<11:43, 27.05s/it]

Highest Loss at index: 0, Loss: 46930488.0
Lowest Loss at index: 266, Loss: 46560148.0
Epoch [74/100], Step [100/100], Loss: 46935000.0


 75%|███████▌  | 75/100 [32:59<11:16, 27.08s/it]

Highest Loss at index: 0, Loss: 46488964.0
Lowest Loss at index: 266, Loss: 46120376.0
Epoch [75/100], Step [100/100], Loss: 46429480.0


 76%|███████▌  | 76/100 [33:26<10:50, 27.11s/it]

Highest Loss at index: 0, Loss: 46065796.0
Lowest Loss at index: 266, Loss: 45698804.0
Epoch [76/100], Step [100/100], Loss: 45850588.0


 77%|███████▋  | 77/100 [33:53<10:23, 27.11s/it]

Highest Loss at index: 0, Loss: 45660804.0
Lowest Loss at index: 266, Loss: 45295508.0
Epoch [77/100], Step [100/100], Loss: 45521740.0


 78%|███████▊  | 78/100 [34:21<09:57, 27.17s/it]

Highest Loss at index: 0, Loss: 45274032.0
Lowest Loss at index: 266, Loss: 44910164.0
Epoch [78/100], Step [100/100], Loss: 45187944.0


 79%|███████▉  | 79/100 [34:48<09:30, 27.16s/it]

Highest Loss at index: 0, Loss: 44905204.0
Lowest Loss at index: 266, Loss: 44542948.0
Epoch [79/100], Step [100/100], Loss: 44765036.0


 80%|████████  | 80/100 [35:15<09:05, 27.29s/it]

Highest Loss at index: 0, Loss: 44553976.0
Lowest Loss at index: 266, Loss: 44193068.0
Epoch [80/100], Step [100/100], Loss: 44338192.0


 81%|████████  | 81/100 [35:42<08:38, 27.27s/it]

Highest Loss at index: 0, Loss: 44220376.0
Lowest Loss at index: 266, Loss: 43860776.0
Epoch [81/100], Step [100/100], Loss: 43995704.0


 82%|████████▏ | 82/100 [36:10<08:11, 27.28s/it]

Highest Loss at index: 0, Loss: 43904584.0
Lowest Loss at index: 266, Loss: 43546168.0
Epoch [82/100], Step [100/100], Loss: 43702828.0


 83%|████████▎ | 83/100 [36:37<07:43, 27.25s/it]

Highest Loss at index: 0, Loss: 43606400.0
Lowest Loss at index: 266, Loss: 43249128.0
Epoch [83/100], Step [100/100], Loss: 43414100.0


 84%|████████▍ | 84/100 [37:04<07:14, 27.17s/it]

Highest Loss at index: 0, Loss: 43325732.0
Lowest Loss at index: 266, Loss: 42969628.0
Epoch [84/100], Step [100/100], Loss: 43167496.0


 85%|████████▌ | 85/100 [37:31<06:45, 27.01s/it]

Highest Loss at index: 0, Loss: 43062576.0
Lowest Loss at index: 266, Loss: 42707564.0
Epoch [85/100], Step [100/100], Loss: 42989568.0


 86%|████████▌ | 86/100 [37:57<06:17, 26.97s/it]

Highest Loss at index: 0, Loss: 42816248.0
Lowest Loss at index: 266, Loss: 42462396.0
Epoch [86/100], Step [100/100], Loss: 42746268.0


 87%|████████▋ | 87/100 [38:24<05:47, 26.74s/it]

Highest Loss at index: 0, Loss: 42587140.0
Lowest Loss at index: 266, Loss: 42234148.0
Epoch [87/100], Step [100/100], Loss: 42449008.0


 88%|████████▊ | 88/100 [38:49<05:17, 26.44s/it]

Highest Loss at index: 0, Loss: 42375300.0
Lowest Loss at index: 266, Loss: 42023100.0
Epoch [88/100], Step [100/100], Loss: 42304544.0


 89%|████████▉ | 89/100 [39:15<04:47, 26.17s/it]

Highest Loss at index: 0, Loss: 42180652.0
Lowest Loss at index: 266, Loss: 41829260.0
Epoch [89/100], Step [100/100], Loss: 42110244.0


 90%|█████████ | 90/100 [39:41<04:19, 25.99s/it]

Highest Loss at index: 0, Loss: 42003188.0
Lowest Loss at index: 266, Loss: 41652416.0
Epoch [90/100], Step [100/100], Loss: 41799016.0


 91%|█████████ | 91/100 [40:06<03:53, 25.93s/it]

Highest Loss at index: 0, Loss: 41842828.0
Lowest Loss at index: 266, Loss: 41492840.0
Epoch [91/100], Step [100/100], Loss: 41704908.0


 92%|█████████▏| 92/100 [40:32<03:26, 25.85s/it]

Highest Loss at index: 0, Loss: 41699276.0
Lowest Loss at index: 266, Loss: 41349932.0
Epoch [92/100], Step [100/100], Loss: 41474848.0


 93%|█████████▎| 93/100 [40:57<03:00, 25.74s/it]

Highest Loss at index: 0, Loss: 41572692.0
Lowest Loss at index: 266, Loss: 41223828.0
Epoch [93/100], Step [100/100], Loss: 41360136.0


 94%|█████████▍| 94/100 [41:23<02:34, 25.68s/it]

Highest Loss at index: 0, Loss: 41463168.0
Lowest Loss at index: 266, Loss: 41114780.0
Epoch [94/100], Step [100/100], Loss: 41320828.0


 95%|█████████▌| 95/100 [41:49<02:08, 25.66s/it]

Highest Loss at index: 0, Loss: 41370648.0
Lowest Loss at index: 266, Loss: 41022648.0
Epoch [95/100], Step [100/100], Loss: 41206664.0


 96%|█████████▌| 96/100 [42:14<01:42, 25.67s/it]

Highest Loss at index: 0, Loss: 41295140.0
Lowest Loss at index: 266, Loss: 40947464.0
Epoch [96/100], Step [100/100], Loss: 41165936.0


 97%|█████████▋| 97/100 [42:40<01:16, 25.62s/it]

Highest Loss at index: 0, Loss: 41236668.0
Lowest Loss at index: 266, Loss: 40889204.0
Epoch [97/100], Step [100/100], Loss: 41170000.0


 98%|█████████▊| 98/100 [43:05<00:51, 25.61s/it]

Highest Loss at index: 0, Loss: 41194104.0
Lowest Loss at index: 266, Loss: 40846812.0
Epoch [98/100], Step [100/100], Loss: 41041440.0


 99%|█████████▉| 99/100 [43:31<00:25, 25.54s/it]

Highest Loss at index: 0, Loss: 41168480.0
Lowest Loss at index: 266, Loss: 40821296.0
Epoch [99/100], Step [100/100], Loss: 41097188.0


100%|██████████| 100/100 [43:56<00:00, 26.37s/it]

Highest Loss at index: 0, Loss: 41159860.0
Lowest Loss at index: 266, Loss: 40812708.0
Epoch [100/100], Step [100/100], Loss: 40955736.0



