<a href="https://colab.research.google.com/github/agrawal-harsh2003/AI_Gym_Trainner/blob/HPE/AI_Gym_Trainner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install dependencies and import the required libraries.

In [None]:
!pip install torch torchvision datasets transformers matplotlib
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from tqdm import tqdm
from transformers import ResNetConfig, ResNetModel, ResNetForImageClassification, AutoImageProcessor

Load the COCO-WholeBody dataset with the Hugging Face `datasets` API.

In [None]:
from datasets import load_dataset

# Fetch every split of the annotated COCO-WholeBody dataset from the Hugging Face Hub.
ds = load_dataset("harsh-7070/COCO-Wholebody-annotated")

# Quick sanity check: the fields of one training record, then the split names.
print(ds["train"][0].keys())
print(ds.keys())

Helper code (currently commented out) used to find the maximum number of keypoints in any single sample of the dataset.

In [None]:
# def find_max_keypoints(dataset):
#     max_keypoints = 0
#     count = 0
#     for item in dataset:
#         keypoints = item['objects']['keypoints']  # Assuming the keypoints are in this field
#         num_keypoints = len(keypoints)
#         max_keypoints = max(max_keypoints, num_keypoints)
#         count += 1
#         if count % 1000 == 0:
#             print(count)

#     print(max_keypoints)

# find_max_keypoints(ds['train'])

Data augmentation, plus helper functions that pad keypoint annotations to a common length and collate samples into batches.

In [None]:
# Data augmentation for the training split.
#
# Only photometric (colour) jitter is used: it changes pixel intensities but not
# pixel positions, so the keypoint coordinates used as regression targets stay
# valid for the augmented image. Geometric augmentations (horizontal flip,
# rotation, random resized crop) move body parts within the image; applying them
# without transforming the keypoint labels the same way (and, for flips, also
# swapping left/right keypoint indices) trains the model on inconsistent targets,
# so they are intentionally not used here.
#
# NOTE: ColorJitter is meant for PIL images or float tensors in [0, 1]; apply it
# before any mean/std normalisation.
data_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
])

def pad_keypoints(keypoints_tensor, target_length, padding_value=-1):
    """Pad a per-keypoint tensor along dim 0 up to ``target_length`` rows.

    Gives every sample the same number of keypoint rows so samples can be
    batched and compared against a fixed-size model output.

    Args:
        keypoints_tensor: Tensor whose first dimension indexes keypoints, e.g.
            ``(N, 3)`` rows of ``(x, y, visibility)``. Any trailing shape is kept.
        target_length: Desired number of rows after padding.
        padding_value: Fill value for the appended rows (``-1`` marks padding).

    Returns:
        A tensor with ``max(N, target_length)`` rows and the same dtype/device as
        the input. Inputs that already have ``target_length`` rows or more are
        returned unchanged (no truncation).
    """
    current_length = keypoints_tensor.shape[0]
    if current_length >= target_length:
        return keypoints_tensor

    # Build the padding with the input's trailing shape, dtype and device so the
    # concatenation works for any per-keypoint layout (not only 3 columns) and
    # never needs dtype promotion or a cross-device copy.
    padding = torch.full(
        (target_length - current_length, *keypoints_tensor.shape[1:]),
        padding_value,
        dtype=keypoints_tensor.dtype,
        device=keypoints_tensor.device,
    )
    return torch.cat((keypoints_tensor, padding), dim=0)

def custom_collate_fn(batch):
    """Collate ``(pixel_values, keypoints)`` samples into two batched tensors.

    Image tensors are stacked along a new leading batch dimension, so they must
    all share one shape. Keypoint tensors may differ in length; shorter ones are
    right-padded with ``-1`` up to the longest sample in the batch.

    Returns:
        ``(images, keypoints)`` shaped ``(B, *image_shape)`` and
        ``(B, L_max, *keypoint_row_shape)``.
    """
    stacked_images = torch.stack([sample[0] for sample in batch], dim=0)
    padded_keypoints = torch.nn.utils.rnn.pad_sequence(
        [sample[1] for sample in batch],
        batch_first=True,
        padding_value=-1,
    )
    return stacked_images, padded_keypoints

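Custom `Dataset` class that pairs each preprocessed image with its padded keypoint tensor, followed by the training hyperparameters.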

In [None]:
class COCOWholeBodyDataset(Dataset):
    """Map-style dataset yielding ``(pixel_values, keypoints)`` pairs.

    Each record of ``data`` provides an ``'image'`` (a PIL image; converted to
    RGB if needed) and ``'objects'['keypoints']`` (a list of values reshaped into
    ``(x, y, visibility)`` rows). Images go through the ``microsoft/resnet-50``
    image processor; keypoints become an ``(N, 3)`` float tensor padded with
    ``-1`` rows up to ``max_keypoints`` rows.

    Args:
        data: Indexable collection of records (e.g. a Hugging Face dataset split).
        transform: Optional augmentation applied to the RGB PIL image *before*
            the image processor resizes and normalises it.
        max_keypoints: Number of keypoint rows every sample is padded to.
    """

    def __init__(self, data, transform=None, max_keypoints=2660):
        self.data = data
        self.transform = transform
        # Default 2660 = 20 merged annotations x 133 COCO-WholeBody keypoints.
        self.max_keypoints = max_keypoints
        self.image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        image = item['image']
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Augment the raw image, not the processor output: the processor returns a
        # mean/std-normalised tensor (values outside [0, 1]), and torchvision's
        # colour transforms clamp float tensors to [0, 1], which would corrupt the
        # input and make training images distributed differently from validation.
        if self.transform:
            image = self.transform(image)

        keypoints = item['objects']['keypoints']
        keypoints_tensor = torch.tensor(keypoints, dtype=torch.float32).view(-1, 3)
        keypoints_tensor = pad_keypoints(keypoints_tensor, self.max_keypoints)

        inputs = self.image_processor(images=image, return_tensors="pt")
        # Drop only the leading batch dimension the processor adds for one image.
        pixel_values = inputs['pixel_values'].squeeze(0)

        return pixel_values, keypoints_tensor

# ---- Hyperparameters ---------------------------------------------------------
num_epochs = 15          # passes over the training split
batch_size = 512         # samples per optimisation step
learning_rate = 1e-3     # optimiser step size
# Maximum keypoint rows per sample: up to 20 annotations (merged into a single
# record when the .jsonl was pre-processed before uploading to the Hub)
# x 133 COCO-WholeBody keypoints each.
num_keypoints = 20 * 133

Create the training and validation datasets from the dataset's existing splits, and wrap each in a `DataLoader`.

In [None]:
# Training samples are augmented; validation samples are only preprocessed.
train_dataset = COCOWholeBodyDataset(ds['train'], transform=data_transforms)
val_dataset = COCOWholeBodyDataset(ds['validation'])

# Both loaders use the custom collate function so variable-length keypoint
# tensors are padded into one batch tensor. Only the training order is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

Model initialization and configuration: an ImageNet-pretrained ResNet-50 backbone (loaded via torchvision) with a regression head that predicts keypoint coordinates.

In [None]:
class ResNetForImageClassification(nn.Module):
    """ResNet-50 backbone with an MLP head that regresses keypoint coordinates.

    NOTE: this class shadows ``transformers.ResNetForImageClassification``
    imported at the top of the notebook; the name is kept so existing cells
    continue to work.

    The backbone is torchvision's ImageNet-pretrained ResNet-50 with its final
    ``fc`` layer replaced by ``nn.Identity``, so it yields 2048-d pooled
    features. Only ``config.num_labels`` is read from ``config``: it sets the
    output width of the head (2 values per keypoint in this notebook). The other
    ResNet hyper-parameters in the config are not used to build the network.

    Args:
        config: Object with a ``num_labels`` attribute (e.g. a ``ResNetConfig``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # IMAGENET1K_V1 is the checkpoint the deprecated ``pretrained=True`` loaded.
        self.resnet_backbone = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
        )
        self.resnet_backbone.fc = nn.Identity()
        feature_dim = 2048  # ResNet-50 pooled feature size
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, config.num_labels),
        )

    def forward(self, x):
        """Map an image batch to flat predictions of shape ``(B, num_labels)``."""
        features = self.resnet_backbone(x)

        if features.dim() == 4:
            # Unpooled feature map: global-average-pool it to (B, C, 1, 1).
            features = nn.functional.adaptive_avg_pool2d(features, (1, 1))
        elif features.dim() != 2:
            raise ValueError(f"Unexpected tensor shape: {features.shape}")

        return self.classifier(torch.flatten(features, 1))

# ResNet-50 style configuration. ``num_labels`` = two regression outputs (x, y)
# per keypoint.
# NOTE(review): the architecture fields below (embedding_size, hidden_sizes,
# depths, ...) are not obviously consumed by the custom model class, whose
# backbone comes from torchvision; confirm before tuning them.
config = ResNetConfig(
    num_channels=3,
    embedding_size=64,
    hidden_sizes=[256, 512, 1024, 2048],
    depths=[3, 4, 6, 3],
    layer_type="bottleneck",
    hidden_act="relu",
    num_labels=2 * num_keypoints,
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ResNetForImageClassification(config).to(device)



In [None]:
# for images, keypoints in train_loader:
#     images = images.to(device)
#     keypoints = keypoints.to(device)
#     print(images.shape)
#     print(keypoints.shape)
#     output = model(images)
#     print(output.shape)
#     break
# torch.Size([32, 3, 224, 224])
# torch.Size([32, 238, 3])
# torch.Size([32, 476])

torch.Size([32, 3, 224, 224])
torch.Size([32, 2660, 3])
torch.Size([32, 5320])


Model Training

In [None]:
import torch.optim as optim

criterion = nn.MSELoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def masked_keypoint_loss(predictions, keypoints):
    """Mean squared error over labelled keypoint coordinates only.

    Args:
        predictions: Model output of shape ``(B, K * 2)`` (flattened x, y pairs).
        keypoints: Targets of shape ``(B, K, 3)`` holding ``(x, y, visibility)``
            rows; padded rows are filled with ``-1``.

    Returns:
        Scalar loss averaged over the coordinates that contribute.
    """
    batch = keypoints.size(0)
    coords = keypoints[:, :, :2].reshape(batch, -1)
    # Keep only keypoints with visibility > 0. This excludes the -1 padding rows
    # and, in the COCO keypoint format, unlabelled keypoints (stored as
    # x = y = v = 0), which would otherwise pull predictions toward the image
    # origin. (Assumes this dataset keeps the COCO visibility convention.)
    mask = (keypoints[:, :, 2:3] > 0).expand(-1, -1, 2).reshape(batch, -1).float()
    per_coord = criterion(predictions, coords) * mask
    # Clamp the denominator so a batch with no labelled keypoints gives 0, not NaN.
    return per_coord.sum() / mask.sum().clamp(min=1.0)


# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    total_samples = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    for images, keypoints in progress:
        images = images.to(device)
        keypoints = keypoints.to(device)

        output = model(images)
        loss = masked_keypoint_loss(output, keypoints)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss * images.size(0)
        total_samples += images.size(0)
        # Report progress on the bar instead of printing once per step.
        progress.set_postfix(loss=f"{batch_loss:.4f}")

    # Guard against an empty loader so the epoch summary never divides by zero.
    avg_loss = running_loss / max(total_samples, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

  0%|          | 0/3697 [00:00<?, ?it/s]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 1/3697 [00:34<35:15:09, 34.34s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 2/3697 [00:57<28:39:15, 27.92s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 3/3697 [01:24<28:09:38, 27.44s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 4/3697 [01:46<26:03:26, 25.40s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 5/3697 [02:08<24:36:33, 24.00s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 6/3697 [02:31<24:08:26, 23.55s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 7/3697 [02:52<23:30:09, 22.93s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 8/3697 [03:13<22:46:56, 22.23s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 9/3697 [03:35<22:46:20, 22.23s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 10/3697 [03:57<22:31:49, 22.00s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 11/3697 [04:18<22:27:31, 21.93s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 12/3697 [04:42<22:53:33, 22.36s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 13/3697 [05:03<22:27:15, 21.94s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 14/3697 [05:25<22:29:21, 21.98s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 15/3697 [05:45<21:58:25, 21.48s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 16/3697 [06:07<21:58:23, 21.49s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 17/3697 [06:28<21:45:23, 21.28s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  0%|          | 18/3697 [06:49<21:53:22, 21.42s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 19/3697 [07:10<21:47:49, 21.33s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 20/3697 [07:32<21:53:42, 21.44s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 21/3697 [07:54<21:55:16, 21.47s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 22/3697 [08:15<21:56:15, 21.49s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 23/3697 [08:37<21:56:22, 21.50s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 24/3697 [08:57<21:38:53, 21.22s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 25/3697 [09:20<22:03:17, 21.62s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 26/3697 [09:40<21:44:40, 21.32s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 27/3697 [10:02<21:54:04, 21.48s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 28/3697 [10:28<23:17:21, 22.85s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 29/3697 [10:49<22:32:43, 22.13s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 30/3697 [11:10<22:23:47, 21.99s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 31/3697 [11:32<22:11:20, 21.79s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 32/3697 [11:54<22:25:27, 22.03s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 33/3697 [12:16<22:21:14, 21.96s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 34/3697 [12:37<22:03:11, 21.67s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 35/3697 [12:59<22:10:42, 21.80s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 36/3697 [13:19<21:42:00, 21.34s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 37/3697 [13:41<21:42:16, 21.35s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 38/3697 [14:03<21:49:52, 21.48s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 39/3697 [14:25<22:00:06, 21.65s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 40/3697 [14:47<22:03:00, 21.71s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 41/3697 [15:08<22:01:21, 21.69s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 42/3697 [15:31<22:18:16, 21.97s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 43/3697 [15:53<22:18:49, 21.98s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 44/3697 [16:15<22:22:36, 22.05s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 45/3697 [16:36<22:07:10, 21.80s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|          | 46/3697 [16:57<21:54:16, 21.60s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|▏         | 47/3697 [17:20<22:06:35, 21.81s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|▏         | 48/3697 [17:41<21:55:32, 21.63s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|▏         | 49/3697 [18:03<22:10:27, 21.88s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|▏         | 50/3697 [18:24<21:47:48, 21.52s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|▏         | 51/3697 [18:45<21:41:43, 21.42s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|▏         | 52/3697 [19:07<21:48:34, 21.54s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|▏         | 53/3697 [19:28<21:44:47, 21.48s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|▏         | 54/3697 [19:50<21:43:02, 21.46s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  1%|▏         | 55/3697 [20:10<21:24:26, 21.16s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 56/3697 [20:38<23:17:38, 23.03s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 57/3697 [20:59<22:53:13, 22.64s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 58/3697 [21:21<22:27:02, 22.21s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 59/3697 [21:42<22:13:26, 21.99s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 60/3697 [22:03<21:45:36, 21.54s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 61/3697 [22:24<21:41:10, 21.47s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 62/3697 [22:44<21:22:41, 21.17s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 63/3697 [23:06<21:30:20, 21.30s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 64/3697 [23:27<21:23:52, 21.20s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 65/3697 [23:49<21:32:05, 21.35s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 66/3697 [24:11<21:46:27, 21.59s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 67/3697 [24:31<21:30:40, 21.33s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 68/3697 [24:54<21:57:59, 21.79s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 69/3697 [25:16<21:48:01, 21.63s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 70/3697 [25:39<22:17:51, 22.13s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 71/3697 [26:00<22:02:23, 21.88s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 72/3697 [26:22<22:05:40, 21.94s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 73/3697 [26:44<22:00:35, 21.86s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 74/3697 [27:05<21:43:34, 21.59s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 75/3697 [27:27<21:50:42, 21.71s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 76/3697 [27:49<21:52:58, 21.76s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 77/3697 [28:11<21:58:58, 21.86s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 78/3697 [28:32<21:51:42, 21.75s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 79/3697 [28:54<21:43:03, 21.61s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 80/3697 [29:16<21:55:12, 21.82s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 81/3697 [29:37<21:39:29, 21.56s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 82/3697 [30:00<22:03:53, 21.97s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 83/3697 [30:22<22:05:52, 22.01s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 84/3697 [30:49<23:41:15, 23.60s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 85/3697 [31:10<22:44:06, 22.66s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 86/3697 [31:32<22:35:19, 22.52s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 87/3697 [31:53<22:02:07, 21.97s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 88/3697 [32:15<22:00:53, 21.96s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 89/3697 [32:36<21:54:41, 21.86s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 90/3697 [32:58<21:49:26, 21.78s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 91/3697 [33:20<21:55:09, 21.88s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  2%|▏         | 92/3697 [33:40<21:31:14, 21.49s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 93/3697 [34:03<21:48:39, 21.79s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 94/3697 [34:24<21:30:23, 21.49s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 95/3697 [34:46<21:36:29, 21.60s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 96/3697 [35:06<21:19:17, 21.32s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 97/3697 [35:30<22:11:16, 22.19s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 98/3697 [35:51<21:44:15, 21.74s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 99/3697 [36:13<21:47:47, 21.81s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 100/3697 [36:36<21:59:22, 22.01s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 101/3697 [36:57<21:40:28, 21.70s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 102/3697 [37:18<21:32:54, 21.58s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 103/3697 [37:39<21:16:15, 21.31s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 104/3697 [38:00<21:25:25, 21.47s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 105/3697 [38:20<20:58:38, 21.02s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 106/3697 [38:42<21:12:30, 21.26s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 107/3697 [39:03<21:05:00, 21.14s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 108/3697 [39:24<21:00:41, 21.08s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 109/3697 [39:45<20:58:56, 21.05s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 110/3697 [40:07<21:10:50, 21.26s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 111/3697 [40:28<21:18:50, 21.40s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 112/3697 [40:54<22:24:50, 22.51s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 113/3697 [41:17<22:48:34, 22.91s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 114/3697 [41:39<22:23:08, 22.49s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 115/3697 [42:00<21:59:28, 22.10s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 116/3697 [42:22<21:49:46, 21.95s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 117/3697 [42:43<21:30:21, 21.63s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 118/3697 [43:04<21:24:33, 21.53s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 119/3697 [43:24<21:01:30, 21.15s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 120/3697 [43:46<21:10:19, 21.31s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 121/3697 [44:08<21:25:17, 21.57s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 122/3697 [44:29<21:17:28, 21.44s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 123/3697 [44:51<21:26:14, 21.59s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 124/3697 [45:13<21:31:52, 21.69s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 125/3697 [45:35<21:42:11, 21.87s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 126/3697 [45:57<21:43:45, 21.91s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 127/3697 [46:18<21:29:25, 21.67s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 128/3697 [46:42<21:56:48, 22.14s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  3%|▎         | 129/3697 [47:02<21:33:17, 21.75s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▎         | 130/3697 [47:27<22:28:37, 22.69s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▎         | 131/3697 [47:49<22:13:07, 22.43s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▎         | 132/3697 [48:10<21:40:48, 21.89s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▎         | 133/3697 [48:32<21:42:24, 21.93s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▎         | 134/3697 [48:54<21:47:13, 22.01s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▎         | 135/3697 [49:15<21:32:05, 21.76s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▎         | 136/3697 [49:37<21:36:09, 21.84s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▎         | 137/3697 [50:00<21:44:39, 21.99s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▎         | 138/3697 [50:21<21:39:22, 21.91s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 139/3697 [50:42<21:15:47, 21.51s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 140/3697 [51:06<22:09:08, 22.42s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 141/3697 [51:28<21:58:17, 22.24s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 142/3697 [51:49<21:35:37, 21.87s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 143/3697 [52:11<21:26:58, 21.73s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 144/3697 [52:31<21:10:27, 21.45s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 145/3697 [52:55<21:38:04, 21.93s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 146/3697 [53:16<21:27:45, 21.76s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 147/3697 [53:37<21:24:45, 21.71s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 148/3697 [54:00<21:32:25, 21.85s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 149/3697 [54:22<21:32:42, 21.86s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 150/3697 [54:43<21:27:09, 21.77s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 151/3697 [55:04<21:14:34, 21.57s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 152/3697 [55:26<21:12:24, 21.54s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 153/3697 [55:47<21:02:48, 21.38s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 154/3697 [56:08<21:05:54, 21.44s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 155/3697 [56:30<21:13:11, 21.57s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 156/3697 [56:51<21:06:48, 21.47s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 157/3697 [57:13<21:09:56, 21.52s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 158/3697 [57:34<20:57:48, 21.32s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 159/3697 [57:56<21:05:03, 21.45s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 160/3697 [58:18<21:22:57, 21.76s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 161/3697 [58:40<21:31:20, 21.91s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 162/3697 [59:02<21:20:55, 21.74s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 163/3697 [59:24<21:35:25, 21.99s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 164/3697 [59:46<21:23:37, 21.80s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 165/3697 [1:00:07<21:09:16, 21.56s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  4%|▍         | 166/3697 [1:00:29<21:14:20, 21.65s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 167/3697 [1:00:51<21:36:09, 22.03s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 168/3697 [1:01:16<22:19:28, 22.77s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 169/3697 [1:01:38<22:06:00, 22.55s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 170/3697 [1:02:01<22:14:00, 22.69s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 171/3697 [1:02:22<21:47:53, 22.26s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 172/3697 [1:02:44<21:39:03, 22.11s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 173/3697 [1:03:05<21:22:34, 21.84s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 174/3697 [1:03:26<21:07:43, 21.59s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 175/3697 [1:03:47<20:58:42, 21.44s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 176/3697 [1:04:08<20:51:06, 21.32s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 177/3697 [1:04:31<21:09:30, 21.64s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 178/3697 [1:04:51<20:47:58, 21.28s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 179/3697 [1:05:13<20:57:20, 21.44s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 180/3697 [1:05:33<20:36:10, 21.09s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 181/3697 [1:05:55<20:43:53, 21.23s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 182/3697 [1:06:16<20:45:14, 21.26s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 183/3697 [1:06:37<20:44:53, 21.26s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▍         | 184/3697 [1:07:00<21:00:01, 21.52s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 185/3697 [1:07:23<21:34:33, 22.12s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 186/3697 [1:07:45<21:27:46, 22.01s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 187/3697 [1:08:05<21:00:38, 21.55s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 188/3697 [1:08:28<21:13:12, 21.77s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 189/3697 [1:08:50<21:25:10, 21.98s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 190/3697 [1:09:11<21:00:38, 21.57s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 191/3697 [1:09:33<21:18:22, 21.88s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 192/3697 [1:09:54<20:57:40, 21.53s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 193/3697 [1:10:16<21:04:32, 21.65s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 194/3697 [1:10:38<21:20:51, 21.94s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 195/3697 [1:11:00<21:05:51, 21.69s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 196/3697 [1:11:27<22:42:54, 23.36s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 197/3697 [1:11:49<22:16:28, 22.91s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 198/3697 [1:12:12<22:17:34, 22.94s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 199/3697 [1:12:33<21:50:05, 22.47s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 200/3697 [1:12:55<21:35:02, 22.22s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 201/3697 [1:13:16<21:23:31, 22.03s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 202/3697 [1:13:38<21:09:38, 21.80s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  5%|▌         | 203/3697 [1:14:00<21:26:33, 22.09s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 204/3697 [1:14:23<21:34:16, 22.23s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 205/3697 [1:14:44<21:07:57, 21.79s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 206/3697 [1:15:06<21:20:21, 22.01s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 207/3697 [1:15:27<20:57:26, 21.62s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 208/3697 [1:15:48<20:51:31, 21.52s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 209/3697 [1:16:09<20:34:57, 21.24s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 210/3697 [1:16:32<21:03:00, 21.73s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 211/3697 [1:16:54<21:13:06, 21.91s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 212/3697 [1:17:14<20:42:54, 21.40s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 213/3697 [1:17:37<21:12:42, 21.92s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 214/3697 [1:17:58<20:49:57, 21.53s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 215/3697 [1:18:20<20:55:47, 21.64s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 216/3697 [1:18:41<20:52:11, 21.58s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 217/3697 [1:19:04<21:07:04, 21.85s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 218/3697 [1:19:26<21:18:44, 22.05s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 219/3697 [1:19:48<21:12:08, 21.95s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 220/3697 [1:20:09<21:04:32, 21.82s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 221/3697 [1:20:30<20:38:39, 21.38s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 222/3697 [1:20:53<21:04:39, 21.84s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 223/3697 [1:21:14<20:50:18, 21.59s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 224/3697 [1:21:40<22:03:42, 22.87s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 225/3697 [1:22:01<21:35:43, 22.39s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 226/3697 [1:22:23<21:29:48, 22.30s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 227/3697 [1:22:44<21:00:25, 21.79s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 228/3697 [1:23:05<20:46:09, 21.55s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 229/3697 [1:23:26<20:34:58, 21.37s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 230/3697 [1:23:47<20:36:17, 21.40s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▌         | 231/3697 [1:24:09<20:46:55, 21.59s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▋         | 232/3697 [1:24:31<20:54:18, 21.72s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▋         | 233/3697 [1:24:53<20:54:28, 21.73s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▋         | 234/3697 [1:25:13<20:34:28, 21.39s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▋         | 235/3697 [1:25:36<20:52:48, 21.71s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▋         | 236/3697 [1:25:57<20:37:12, 21.45s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▋         | 237/3697 [1:26:19<20:51:20, 21.70s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▋         | 238/3697 [1:26:42<21:08:50, 22.01s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▋         | 239/3697 [1:27:03<21:01:01, 21.88s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  6%|▋         | 240/3697 [1:27:25<21:05:05, 21.96s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  7%|▋         | 241/3697 [1:27:46<20:48:31, 21.68s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  7%|▋         | 242/3697 [1:28:08<20:45:56, 21.64s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  7%|▋         | 243/3697 [1:28:30<21:00:23, 21.89s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
torch.Size([32, 5320])


  7%|▋         | 244/3697 [1:28:51<20:34:34, 21.45s/it]

torch.Size([32, 2660, 2]) 

torch.Size([32, 5320])
