In [1]:
import librosa
import pandas as pd
import torch.optim as optim
import numpy as np
import csv
from data_load import TextTransform
import os
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
#from decoder import GreedyDecoder
from torch.functional import F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
import torch
import torch.nn as nn
import librosa
from librosa.core import stft, magphase
from glob import glob
from torch import autograd
import csv
from data_load import CodeSwitchDataset
import zipfile

  for doc in docs:


In [2]:
def pad(wav, trans, lang, max_len=None):
    """Tile a waveform (and its transcript) until it reaches a target length.

    The audio is extended by repeating it from its start. The transcript is
    extended by repeating a prefix whose length is proportional to the amount
    of audio appended, so text and audio stay roughly proportional.

    Args:
        wav: 1-D sample array (anything ``np.append`` and slicing accept).
        trans: transcript string (or any sequence supporting ``+`` and slicing).
        lang: "Gujarati", "Telugu" or "Tamil"; selects the default target length.
        max_len: optional explicit target length in samples. When given, it
            overrides the per-language default.

    Returns:
        ``(wav, trans)``. If ``wav`` started shorter than the target, it comes
        back with exactly the target length. Otherwise both are returned unchanged.

    Raises:
        ValueError: ``lang`` is unknown (message "Check Language"), or ``wav``
            is empty while padding is required (tiling an empty signal would
            never terminate).
    """
    # Per-language target lengths in samples; 0 means "never pad".
    # NOTE(review): only Telugu has a non-zero value; confirm Gujarati/Tamil.
    default_lengths = {"Gujarati": 0, "Telugu": 529862, "Tamil": 0}
    if lang not in default_lengths:
        raise ValueError("Check Language")
    if max_len is None:
        max_len = default_lengths[lang]

    if max_len > 0 and len(wav) == 0:
        raise ValueError("cannot pad an empty waveform")

    while len(wav) < max_len:
        orig_len = len(wav)
        # Slicing caps at the current length, so this may be shorter than the
        # remaining gap; keep the chunk to know how much was really appended.
        appended = wav[:max_len - orig_len]
        wav = np.append(wav, appended)
        # Scale by the length *before* appending: the appended chunk spans
        # len(appended) / orig_len of the current transcript.
        ratio = int(len(trans) * len(appended) / orig_len)
        trans += trans[:ratio]
    return wav, trans

def preprocess(data):
    """Collate raw dataset items into padded spectrogram and label batches.

    Args:
        data: iterable of ``(wav, sr, trans, lang)`` tuples. These are presumably
            a 1-D waveform, its sample rate, the transcript string and the
            language name (TODO: confirm against CodeSwitchDataset).

    Returns:
        ``(inputs, labels, input_lengths, label_lengths)``:
        - ``inputs``: batch-first, zero-padded tensor of log(1 + |STFT|)
          features, shaped ``(batch, frames, freq_bins)``.
        - ``labels``: batch-first, zero-padded tensor of character indices
          from ``TextTransform.text_to_int``.
        - ``input_lengths`` / ``label_lengths``: each item's unpadded frame
          count and label count.
    """
    # Loop-invariant: build the text mapper once per batch, not once per item.
    text_transform = TextTransform()
    inputs, labels, input_lengths, label_lengths = [], [], [], []

    for wav, sr, trans, lang in data:
        wav, trans = pad(wav, trans, lang)
        # Window = 0.02 * sr samples, hop = 0.01 * sr samples
        # (20 ms / 10 ms if sr is the sample rate in Hz).
        spec = stft(wav, win_length=int(sr * 0.02), hop_length=int(sr * 0.01))
        spec = np.transpose(spec, axes=(1, 0))  # -> (frames, freq_bins)
        magnitude = magphase(spec)[0]
        # Vectorized log(1 + x); log1p is also more accurate for tiny magnitudes.
        feats = torch.from_numpy(np.log1p(magnitude))
        target = torch.Tensor(text_transform.text_to_int(trans.lower()))

        inputs.append(feats)
        labels.append(target)
        input_lengths.append(feats.shape[0])
        label_lengths.append(len(target))

    inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
    return inputs, labels, input_lengths, label_lengths

In [3]:
def GreedyDecoder(output, labels, label_lengths, blank_label=28, collapse_repeated=False,
                  text_transform=None):
    """Best-path (greedy) CTC decoding of a batch of network outputs.

    Args:
        output: tensor of per-frame class scores. The argmax is taken over
            dim 2, so it is indexed as ``(batch, time, classes)``.
        labels: padded tensor of reference character indices, one row per item.
        label_lengths: unpadded length of each row of ``labels``.
        blank_label: class index treated as the CTC blank and dropped.
        collapse_repeated: if True, a frame whose argmax equals the previous
            frame's argmax is skipped (CTC repeat collapsing).
        text_transform: object providing ``int_to_text``. A fresh
            ``TextTransform()`` is created when omitted.

    Returns:
        ``(decodes, targets)``: decoded prediction strings and reference
        strings, one per batch item.
    """
    # Previously this read a global ``text_transform`` that no cell shown here
    # defines (preprocess only creates a local one), so calls hit NameError.
    if text_transform is None:
        text_transform = TextTransform()

    arg_maxes = torch.argmax(output, dim=2)
    decodes = []
    targets = []
    for i, args in enumerate(arg_maxes):
        targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist()))
        frame_ids = args.tolist()  # plain ints: avoids per-element tensor ops
        decode = []
        for j, index in enumerate(frame_ids):
            if index == blank_label:
                continue
            if collapse_repeated and j != 0 and index == frame_ids[j - 1]:
                continue
            decode.append(index)
        decodes.append(text_transform.int_to_text(decode))
    return decodes, targets

In [5]:
train_dataset = CodeSwitchDataset(lang="Telugu", mode="train")

# Split configuration: hold out a fixed fraction of the training data.
validation_split = .2
shuffle_dataset = True
random_seed = 42

dataset_size = len(train_dataset)
split = int(np.floor(validation_split * dataset_size))  # size of the validation subset
indices = list(range(dataset_size))
if shuffle_dataset:
    # Seeds numpy's *global* RNG so the same split is produced on every run.
    np.random.seed(random_seed)
    np.random.shuffle(indices)

# The first `split` (shuffled) indices form the validation set; the rest train.
val_indices = indices[:split]
train_indices = indices[split:]

# PyTorch samplers that draw, in random order, only from their own subset.
train_sampler, valid_sampler = (
    SubsetRandomSampler(train_indices),
    SubsetRandomSampler(val_indices),
)


In [7]:
# Settings shared by both loaders. The samplers keep the two splits disjoint
# even though both loaders read from the same dataset. `preprocess` is passed
# directly; the former `lambda x: preprocess(x)` wrapper added nothing.
loader_kwargs = dict(batch_size=16, num_workers=6, collate_fn=preprocess)

train_loader = DataLoader(train_dataset,
                          sampler=train_sampler,
                          drop_last=True,
                          **loader_kwargs)
# Keep the final partial batch for evaluation. With drop_last=True, up to
# batch_size - 1 validation samples were never scored.
test_loader = DataLoader(train_dataset,
                         sampler=valid_sampler,
                         drop_last=False,
                         **loader_kwargs)

In [8]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
class Model(nn.Module):
    """Stacked LSTM followed by a per-frame linear classifier.

    Args:
        input_dim: features per input frame.
        hidden_dim: LSTM hidden size.
        batch_size: stored on the instance only; forward() does not use it.
        output_dim: number of output classes per frame.
        num_layers: number of stacked LSTM layers.
        batch_first: forwarded to ``nn.LSTM``.
            - ``False`` (default, original behavior) expects ``(time, batch, input_dim)``.
            - ``True`` expects ``(batch, time, input_dim)``, the layout
              ``preprocess`` builds with ``pad_sequence(batch_first=True)``.
            NOTE(review): with the default, batch-first inputs make the LSTM
            scan across the batch axis. Confirm the training loop transposes
            its inputs, or construct the model with ``batch_first=True``.
        dropout: dropout between stacked LSTM layers (``nn.LSTM`` semantics).
            The default 0.0 keeps the original behavior.
    """

    def __init__(self, input_dim, hidden_dim, batch_size, output_dim=29, num_layers=4,
                 batch_first=False, dropout=0.0):
        super(Model, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.batch_first = batch_first
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers,
                            batch_first=batch_first, dropout=dropout)
        self.linear = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, x):
        """Return ``(scores, hidden)``.

        ``scores`` keeps the two leading axes of ``x`` and has ``output_dim``
        as its last axis; no softmax is applied. ``hidden`` is the
        ``(h_n, c_n)`` tuple returned by ``nn.LSTM``.
        """
        lstm_out, hidden = self.lstm(x)
        out = self.linear(lstm_out)
        return out, hidden

In [10]:
# input_dim must equal the feature width preprocess produces
# (the last axis, 1025, of the batch shapes printed further below).
model = Model(input_dim=1025,
              hidden_dim=512,
              batch_size=16,
              output_dim=29,
              num_layers=4)
model = model.to(device)
# The CTC blank index must agree with GreedyDecoder's default blank_label=28,
# the last of the 29 output classes. nn.CTCLoss defaults to blank=0, so the
# loss would otherwise train a different class as blank than the decoder strips.
criterion = nn.CTCLoss(blank=28).to(device)
epochs = 40
optimizer = optim.Adam(model.parameters())

In [11]:
# Return cached-but-unused allocator blocks to the GPU driver. When CUDA is
# unavailable there is nothing to release.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [10]:
# Sanity check: iterate the training loader once and report each batch's
# feature tensor shape. This decodes a full epoch; add a `break` if a single
# batch is enough.
data_len = len(train_loader.dataset)
pbar = tqdm(enumerate(train_loader), total=len(train_loader))
for batch_idx, _data in pbar:
    wav, labels, input_lengths, label_lengths = _data
    # tqdm.write prints above the bar instead of tearing it; bare print()
    # interleaves with the bar's redraws, as the output below shows.
    tqdm.write(str(wav.shape))

  0%|          | 1/1061 [00:09<2:48:30,  9.54s/it]

torch.Size([16, 2409, 1025])


  0%|          | 2/1061 [00:10<2:02:44,  6.95s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  1%|          | 6/1061 [00:11<1:26:32,  4.92s/it]

torch.Size([16, 2409, 1025])


  1%|          | 7/1061 [00:15<1:24:12,  4.79s/it]

torch.Size([16, 2409, 1025])


  1%|          | 8/1061 [00:16<1:04:47,  3.69s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  1%|          | 12/1061 [00:17<46:14,  2.64s/it] 

torch.Size([16, 2409, 1025])


  1%|          | 13/1061 [00:22<57:58,  3.32s/it]

torch.Size([16, 2409, 1025])


  1%|▏         | 14/1061 [00:22<41:56,  2.40s/it]

torch.Size([16, 2409, 1025])


  1%|▏         | 15/1061 [00:23<31:27,  1.80s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  2%|▏         | 18/1061 [00:23<22:38,  1.30s/it]

torch.Size([16, 2409, 1025])


  2%|▏         | 19/1061 [00:28<42:29,  2.45s/it]

torch.Size([16, 2409, 1025])


  2%|▏         | 20/1061 [00:28<31:18,  1.80s/it]

torch.Size([16, 2409, 1025])


  2%|▏         | 21/1061 [00:29<25:27,  1.47s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  2%|▏         | 25/1061 [00:34<24:31,  1.42s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  3%|▎         | 27/1061 [00:35<19:39,  1.14s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  3%|▎         | 31/1061 [00:40<20:15,  1.18s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  3%|▎         | 33/1061 [00:42<16:54,  1.01it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  3%|▎         | 37/1061 [00:47<18:11,  1.07s/it]

torch.Size([16, 2409, 1025])


  4%|▎         | 38/1061 [00:47<14:34,  1.17it/s]

torch.Size([16, 2409, 1025])


  4%|▍         | 42/1061 [00:48<09:56,  1.71it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  4%|▍         | 43/1061 [00:52<27:36,  1.63s/it]

torch.Size([16, 2409, 1025])


  5%|▍         | 48/1061 [00:53<18:33,  1.10s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  5%|▍         | 49/1061 [00:57<32:57,  1.95s/it]

torch.Size([16, 2409, 1025])


  5%|▍         | 50/1061 [00:59<33:30,  1.99s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  5%|▌         | 55/1061 [01:03<27:01,  1.61s/it]

torch.Size([16, 2409, 1025])


  5%|▌         | 56/1061 [01:05<28:45,  1.72s/it]

torch.Size([16, 2409, 1025])


  5%|▌         | 57/1061 [01:06<23:17,  1.39s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  6%|▌         | 61/1061 [01:10<21:16,  1.28s/it]

torch.Size([16, 2409, 1025])


  6%|▌         | 62/1061 [01:11<22:23,  1.35s/it]

torch.Size([16, 2409, 1025])


  6%|▌         | 63/1061 [01:12<17:40,  1.06s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  6%|▋         | 67/1061 [01:15<17:03,  1.03s/it]

torch.Size([16, 2409, 1025])


  6%|▋         | 68/1061 [01:17<18:07,  1.10s/it]

torch.Size([16, 2409, 1025])


  7%|▋         | 69/1061 [01:17<16:45,  1.01s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  7%|▋         | 72/1061 [01:18<12:44,  1.29it/s]

torch.Size([16, 2409, 1025])


  7%|▋         | 73/1061 [01:21<23:35,  1.43s/it]

torch.Size([16, 2409, 1025])


  7%|▋         | 74/1061 [01:22<22:51,  1.39s/it]

torch.Size([16, 2409, 1025])


  7%|▋         | 75/1061 [01:23<20:02,  1.22s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  7%|▋         | 78/1061 [01:24<14:57,  1.10it/s]

torch.Size([16, 2409, 1025])


  7%|▋         | 79/1061 [01:27<24:22,  1.49s/it]

torch.Size([16, 2409, 1025])


  8%|▊         | 81/1061 [01:29<18:58,  1.16s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  8%|▊         | 84/1061 [01:29<14:23,  1.13it/s]

torch.Size([16, 2409, 1025])


  8%|▊         | 85/1061 [01:33<25:47,  1.59s/it]

torch.Size([16, 2409, 1025])


  8%|▊         | 86/1061 [01:34<24:59,  1.54s/it]

torch.Size([16, 2409, 1025])


  8%|▊         | 87/1061 [01:34<18:35,  1.15s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  8%|▊         | 90/1061 [01:35<14:40,  1.10it/s]

torch.Size([16, 2409, 1025])


  9%|▊         | 91/1061 [01:38<23:41,  1.47s/it]

torch.Size([16, 2409, 1025])


  9%|▊         | 92/1061 [01:39<22:17,  1.38s/it]

torch.Size([16, 2409, 1025])


  9%|▉         | 93/1061 [01:40<19:59,  1.24s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


  9%|▉         | 96/1061 [01:41<15:34,  1.03it/s]

torch.Size([16, 2409, 1025])


  9%|▉         | 97/1061 [01:43<21:54,  1.36s/it]

torch.Size([16, 2409, 1025])


  9%|▉         | 98/1061 [01:45<23:10,  1.44s/it]

torch.Size([16, 2409, 1025])


  9%|▉         | 99/1061 [01:46<19:05,  1.19s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 10%|▉         | 102/1061 [01:47<15:03,  1.06it/s]

torch.Size([16, 2409, 1025])


 10%|▉         | 103/1061 [01:49<22:19,  1.40s/it]

torch.Size([16, 2409, 1025])


 10%|▉         | 104/1061 [01:51<22:30,  1.41s/it]

torch.Size([16, 2409, 1025])


 10%|▉         | 105/1061 [01:52<19:57,  1.25s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 10%|█         | 108/1061 [01:52<15:26,  1.03it/s]

torch.Size([16, 2409, 1025])


 10%|█         | 109/1061 [01:55<21:48,  1.37s/it]

torch.Size([16, 2409, 1025])


 10%|█         | 110/1061 [01:57<24:17,  1.53s/it]

torch.Size([16, 2409, 1025])


 10%|█         | 111/1061 [01:57<20:38,  1.30s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 11%|█         | 114/1061 [01:58<15:41,  1.01it/s]

torch.Size([16, 2409, 1025])


 11%|█         | 115/1061 [02:01<25:35,  1.62s/it]

torch.Size([16, 2409, 1025])


 11%|█         | 116/1061 [02:03<23:22,  1.48s/it]

torch.Size([16, 2409, 1025])


 11%|█         | 117/1061 [02:04<21:19,  1.36s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 11%|█▏        | 120/1061 [02:04<15:31,  1.01it/s]

torch.Size([16, 2409, 1025])


 11%|█▏        | 121/1061 [02:08<27:37,  1.76s/it]

torch.Size([16, 2409, 1025])


 11%|█▏        | 122/1061 [02:09<24:01,  1.54s/it]

torch.Size([16, 2409, 1025])


 12%|█▏        | 123/1061 [02:10<23:31,  1.50s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 12%|█▏        | 126/1061 [02:10<17:01,  1.09s/it]

torch.Size([16, 2409, 1025])


 12%|█▏        | 127/1061 [02:14<26:53,  1.73s/it]

torch.Size([16, 2409, 1025])


 12%|█▏        | 128/1061 [02:15<22:58,  1.48s/it]

torch.Size([16, 2409, 1025])


 12%|█▏        | 129/1061 [02:16<22:12,  1.43s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 12%|█▏        | 132/1061 [02:16<16:27,  1.06s/it]

torch.Size([16, 2409, 1025])


 13%|█▎        | 133/1061 [02:20<28:20,  1.83s/it]

torch.Size([16, 2409, 1025])


 13%|█▎        | 134/1061 [02:21<22:59,  1.49s/it]

torch.Size([16, 2409, 1025])


 13%|█▎        | 135/1061 [02:21<19:26,  1.26s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 13%|█▎        | 138/1061 [02:23<15:26,  1.00s/it]

torch.Size([16, 2409, 1025])


 13%|█▎        | 139/1061 [02:26<25:41,  1.67s/it]

torch.Size([16, 2409, 1025])


 13%|█▎        | 140/1061 [02:27<20:41,  1.35s/it]

torch.Size([16, 2409, 1025])


 13%|█▎        | 141/1061 [02:27<17:20,  1.13s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 14%|█▎        | 144/1061 [02:29<14:08,  1.08it/s]

torch.Size([16, 2409, 1025])


 14%|█▎        | 145/1061 [02:32<26:17,  1.72s/it]

torch.Size([16, 2409, 1025])


 14%|█▍        | 146/1061 [02:33<21:14,  1.39s/it]

torch.Size([16, 2409, 1025])


 14%|█▍        | 147/1061 [02:33<16:23,  1.08s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 14%|█▍        | 150/1061 [02:35<13:42,  1.11it/s]

torch.Size([16, 2409, 1025])


 14%|█▍        | 151/1061 [02:39<28:57,  1.91s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 14%|█▍        | 153/1061 [02:39<21:21,  1.41s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 15%|█▍        | 156/1061 [02:41<17:06,  1.13s/it]

torch.Size([16, 2409, 1025])


 15%|█▍        | 157/1061 [02:45<29:31,  1.96s/it]

torch.Size([16, 2409, 1025])


 15%|█▍        | 158/1061 [02:45<21:54,  1.46s/it]

torch.Size([16, 2409, 1025])


 15%|█▍        | 159/1061 [02:46<19:59,  1.33s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 15%|█▌        | 162/1061 [02:47<15:09,  1.01s/it]

torch.Size([16, 2409, 1025])


 15%|█▌        | 163/1061 [02:51<30:02,  2.01s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 16%|█▌        | 165/1061 [02:52<22:49,  1.53s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 16%|█▌        | 168/1061 [02:53<17:26,  1.17s/it]

torch.Size([16, 2409, 1025])


 16%|█▌        | 170/1061 [02:57<21:01,  1.42s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 16%|█▌        | 171/1061 [02:58<18:40,  1.26s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 16%|█▋        | 174/1061 [02:59<14:45,  1.00it/s]

torch.Size([16, 2409, 1025])


 16%|█▋        | 175/1061 [03:03<27:04,  1.83s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 17%|█▋        | 177/1061 [03:04<22:10,  1.50s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 17%|█▋        | 180/1061 [03:06<17:36,  1.20s/it]

torch.Size([16, 2409, 1025])


 17%|█▋        | 181/1061 [03:08<24:36,  1.68s/it]

torch.Size([16, 2409, 1025])


 17%|█▋        | 182/1061 [03:09<18:05,  1.24s/it]

torch.Size([16, 2409, 1025])


 17%|█▋        | 183/1061 [03:11<22:41,  1.55s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 18%|█▊        | 186/1061 [03:12<16:58,  1.16s/it]

torch.Size([16, 2409, 1025])


 18%|█▊        | 187/1061 [03:15<25:01,  1.72s/it]

torch.Size([16, 2409, 1025])


 18%|█▊        | 188/1061 [03:16<22:02,  1.51s/it]

torch.Size([16, 2409, 1025])


 18%|█▊        | 192/1061 [03:17<15:25,  1.07s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 18%|█▊        | 193/1061 [03:20<23:21,  1.61s/it]

torch.Size([16, 2409, 1025])


 18%|█▊        | 194/1061 [03:21<20:58,  1.45s/it]

torch.Size([16, 2409, 1025])


 18%|█▊        | 195/1061 [03:23<22:59,  1.59s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 19%|█▉        | 199/1061 [03:26<19:26,  1.35s/it]

torch.Size([16, 2409, 1025])


 19%|█▉        | 200/1061 [03:27<17:46,  1.24s/it]

torch.Size([16, 2409, 1025])


 19%|█▉        | 201/1061 [03:29<17:53,  1.25s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 19%|█▉        | 204/1061 [03:29<13:25,  1.06it/s]

torch.Size([16, 2409, 1025])


 19%|█▉        | 205/1061 [03:32<20:07,  1.41s/it]

torch.Size([16, 2409, 1025])


 19%|█▉        | 206/1061 [03:33<19:50,  1.39s/it]

torch.Size([16, 2409, 1025])


 20%|█▉        | 207/1061 [03:35<19:27,  1.37s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 20%|█▉        | 210/1061 [03:36<14:56,  1.05s/it]

torch.Size([16, 2409, 1025])


 20%|█▉        | 211/1061 [03:37<18:04,  1.28s/it]

torch.Size([16, 2409, 1025])


 20%|█▉        | 212/1061 [03:39<19:42,  1.39s/it]

torch.Size([16, 2409, 1025])


 20%|██        | 213/1061 [03:40<19:59,  1.41s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 20%|██        | 216/1061 [03:42<15:48,  1.12s/it]

torch.Size([16, 2409, 1025])


 20%|██        | 217/1061 [03:43<15:50,  1.13s/it]

torch.Size([16, 2409, 1025])


 21%|██        | 218/1061 [03:45<20:36,  1.47s/it]

torch.Size([16, 2409, 1025])


 21%|██        | 219/1061 [03:47<20:21,  1.45s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 21%|██        | 222/1061 [03:48<15:47,  1.13s/it]

torch.Size([16, 2409, 1025])


 21%|██        | 223/1061 [03:49<15:40,  1.12s/it]

torch.Size([16, 2409, 1025])


 21%|██        | 224/1061 [03:51<20:33,  1.47s/it]

torch.Size([16, 2409, 1025])


 21%|██        | 225/1061 [03:52<19:28,  1.40s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 21%|██▏       | 228/1061 [03:54<15:28,  1.11s/it]

torch.Size([16, 2409, 1025])


 22%|██▏       | 229/1061 [03:55<14:55,  1.08s/it]

torch.Size([16, 2409, 1025])


 22%|██▏       | 230/1061 [03:57<18:25,  1.33s/it]

torch.Size([16, 2409, 1025])


 22%|██▏       | 231/1061 [03:59<23:16,  1.68s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 22%|██▏       | 234/1061 [04:00<17:04,  1.24s/it]

torch.Size([16, 2409, 1025])


 22%|██▏       | 235/1061 [04:00<15:00,  1.09s/it]

torch.Size([16, 2409, 1025])


 22%|██▏       | 236/1061 [04:02<18:28,  1.34s/it]

torch.Size([16, 2409, 1025])


 22%|██▏       | 237/1061 [04:05<24:29,  1.78s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 23%|██▎       | 240/1061 [04:06<17:35,  1.29s/it]

torch.Size([16, 2409, 1025])


 23%|██▎       | 241/1061 [04:06<15:01,  1.10s/it]

torch.Size([16, 2409, 1025])


 23%|██▎       | 242/1061 [04:08<19:16,  1.41s/it]

torch.Size([16, 2409, 1025])


 23%|██▎       | 243/1061 [04:11<25:30,  1.87s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 23%|██▎       | 247/1061 [04:12<13:51,  1.02s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 23%|██▎       | 248/1061 [04:15<19:51,  1.47s/it]

torch.Size([16, 2409, 1025])


 23%|██▎       | 249/1061 [04:18<26:38,  1.97s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 24%|██▍       | 252/1061 [04:19<19:33,  1.45s/it]

torch.Size([16, 2409, 1025])


 24%|██▍       | 253/1061 [04:19<14:47,  1.10s/it]

torch.Size([16, 2409, 1025])


 24%|██▍       | 254/1061 [04:21<20:10,  1.50s/it]

torch.Size([16, 2409, 1025])


 24%|██▍       | 255/1061 [04:25<27:18,  2.03s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 24%|██▍       | 259/1061 [04:25<19:25,  1.45s/it]

torch.Size([16, 2409, 1025])


 25%|██▍       | 260/1061 [04:27<21:44,  1.63s/it]

torch.Size([16, 2409, 1025])


 25%|██▍       | 261/1061 [04:31<30:51,  2.31s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 25%|██▌       | 266/1061 [04:33<23:05,  1.74s/it]

torch.Size([16, 2409, 1025])


 25%|██▌       | 267/1061 [04:37<29:54,  2.26s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 25%|██▌       | 270/1061 [04:37<21:59,  1.67s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 26%|██▌       | 272/1061 [04:40<19:35,  1.49s/it]

torch.Size([16, 2409, 1025])


 26%|██▌       | 273/1061 [04:42<25:20,  1.93s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 26%|██▌       | 276/1061 [04:43<18:44,  1.43s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 26%|██▌       | 278/1061 [04:46<17:36,  1.35s/it]

torch.Size([16, 2409, 1025])


 26%|██▋       | 279/1061 [04:48<21:08,  1.62s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 27%|██▋       | 282/1061 [04:49<16:32,  1.27s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 27%|██▋       | 284/1061 [04:51<15:41,  1.21s/it]

torch.Size([16, 2409, 1025])


 27%|██▋       | 285/1061 [04:53<18:49,  1.46s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 27%|██▋       | 288/1061 [04:55<14:54,  1.16s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 27%|██▋       | 290/1061 [04:57<14:14,  1.11s/it]

torch.Size([16, 2409, 1025])


 27%|██▋       | 291/1061 [04:59<19:42,  1.54s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 28%|██▊       | 294/1061 [05:01<15:44,  1.23s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 28%|██▊       | 296/1061 [05:03<14:37,  1.15s/it]

torch.Size([16, 2409, 1025])


 28%|██▊       | 297/1061 [05:05<19:36,  1.54s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 28%|██▊       | 300/1061 [05:07<15:33,  1.23s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 28%|██▊       | 302/1061 [05:09<14:22,  1.14s/it]

torch.Size([16, 2409, 1025])


 29%|██▊       | 303/1061 [05:11<18:46,  1.49s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 29%|██▉       | 306/1061 [05:13<16:01,  1.27s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 29%|██▉       | 308/1061 [05:14<12:44,  1.02s/it]

torch.Size([16, 2409, 1025])


 29%|██▉       | 309/1061 [05:17<20:29,  1.63s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 29%|██▉       | 312/1061 [05:19<17:02,  1.37s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 30%|██▉       | 314/1061 [05:20<12:53,  1.04s/it]

torch.Size([16, 2409, 1025])


 30%|██▉       | 315/1061 [05:22<18:22,  1.48s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 30%|██▉       | 318/1061 [05:25<16:16,  1.31s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 30%|███       | 320/1061 [05:25<11:55,  1.04it/s]

torch.Size([16, 2409, 1025])


 30%|███       | 321/1061 [05:28<17:33,  1.42s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 31%|███       | 324/1061 [05:31<15:55,  1.30s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 31%|███       | 327/1061 [05:34<14:20,  1.17s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 31%|███       | 330/1061 [05:37<14:25,  1.18s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 31%|███▏      | 333/1061 [05:40<13:09,  1.08s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 32%|███▏      | 336/1061 [05:43<12:54,  1.07s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 32%|███▏      | 339/1061 [05:46<12:33,  1.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 32%|███▏      | 342/1061 [05:49<12:13,  1.02s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 33%|███▎      | 345/1061 [05:52<12:03,  1.01s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 33%|███▎      | 350/1061 [05:55<08:25,  1.41it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 33%|███▎      | 351/1061 [05:57<15:15,  1.29s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 33%|███▎      | 354/1061 [06:00<14:10,  1.20s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 34%|███▎      | 356/1061 [06:01<10:33,  1.11it/s]

torch.Size([16, 2409, 1025])


 34%|███▎      | 357/1061 [06:03<15:28,  1.32s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 34%|███▍      | 360/1061 [06:06<14:44,  1.26s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 34%|███▍      | 363/1061 [06:08<12:42,  1.09s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 34%|███▍      | 366/1061 [06:12<13:12,  1.14s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 35%|███▍      | 369/1061 [06:14<11:25,  1.01it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 35%|███▌      | 372/1061 [06:18<12:31,  1.09s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 35%|███▌      | 375/1061 [06:21<11:34,  1.01s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 36%|███▌      | 378/1061 [06:24<11:33,  1.02s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 36%|███▌      | 381/1061 [06:26<10:59,  1.03it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 36%|███▌      | 384/1061 [06:29<10:56,  1.03it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 36%|███▋      | 387/1061 [06:32<11:10,  1.01it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 37%|███▋      | 390/1061 [06:35<10:37,  1.05it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 37%|███▋      | 392/1061 [06:36<08:42,  1.28it/s]

torch.Size([16, 2409, 1025])


 37%|███▋      | 393/1061 [06:38<13:02,  1.17s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 38%|███▊      | 398/1061 [06:41<09:04,  1.22it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 38%|███▊      | 399/1061 [06:44<16:00,  1.45s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 38%|███▊      | 402/1061 [06:46<13:56,  1.27s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 38%|███▊      | 404/1061 [06:47<10:14,  1.07it/s]

torch.Size([16, 2409, 1025])


 38%|███▊      | 405/1061 [06:50<19:18,  1.77s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 38%|███▊      | 408/1061 [06:53<15:47,  1.45s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 39%|███▊      | 411/1061 [06:56<15:09,  1.40s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 39%|███▉      | 414/1061 [06:58<12:40,  1.17s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 39%|███▉      | 417/1061 [07:02<12:58,  1.21s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 40%|███▉      | 420/1061 [07:04<11:12,  1.05s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 40%|███▉      | 423/1061 [07:08<11:26,  1.08s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 40%|████      | 426/1061 [07:10<10:09,  1.04it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 40%|████      | 429/1061 [07:13<10:55,  1.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 41%|████      | 432/1061 [07:15<09:41,  1.08it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 41%|████      | 435/1061 [07:20<11:23,  1.09s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 41%|████▏     | 438/1061 [07:22<09:48,  1.06it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 41%|████▏     | 440/1061 [07:22<07:43,  1.34it/s]

torch.Size([16, 2409, 1025])


 42%|████▏     | 441/1061 [07:25<15:27,  1.50s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 42%|████▏     | 444/1061 [07:27<12:45,  1.24s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 42%|████▏     | 446/1061 [07:28<10:09,  1.01it/s]

torch.Size([16, 2409, 1025])


 42%|████▏     | 447/1061 [07:31<15:32,  1.52s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 42%|████▏     | 450/1061 [07:33<12:45,  1.25s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 43%|████▎     | 452/1061 [07:34<10:21,  1.02s/it]

torch.Size([16, 2409, 1025])


 43%|████▎     | 453/1061 [07:36<15:21,  1.52s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 43%|████▎     | 456/1061 [07:39<13:21,  1.32s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 43%|████▎     | 458/1061 [07:39<09:42,  1.03it/s]

torch.Size([16, 2409, 1025])


 43%|████▎     | 459/1061 [07:43<17:25,  1.74s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 44%|████▎     | 462/1061 [07:45<14:12,  1.42s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 44%|████▎     | 464/1061 [07:46<11:09,  1.12s/it]

torch.Size([16, 2409, 1025])


 44%|████▍     | 465/1061 [07:49<16:31,  1.66s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 44%|████▍     | 468/1061 [07:51<13:20,  1.35s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 44%|████▍     | 470/1061 [07:52<11:14,  1.14s/it]

torch.Size([16, 2409, 1025])


 44%|████▍     | 471/1061 [07:55<16:09,  1.64s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 45%|████▍     | 474/1061 [07:56<12:50,  1.31s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 45%|████▍     | 476/1061 [07:58<10:48,  1.11s/it]

torch.Size([16, 2409, 1025])


 45%|████▍     | 477/1061 [08:00<16:00,  1.64s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 45%|████▌     | 480/1061 [08:02<12:48,  1.32s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 45%|████▌     | 482/1061 [08:03<10:22,  1.08s/it]

torch.Size([16, 2409, 1025])


 46%|████▌     | 483/1061 [08:06<15:28,  1.61s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 46%|████▌     | 486/1061 [08:08<12:41,  1.33s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 46%|████▌     | 488/1061 [08:09<10:32,  1.10s/it]

torch.Size([16, 2409, 1025])


 46%|████▌     | 489/1061 [08:12<14:54,  1.56s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 46%|████▋     | 492/1061 [08:14<12:16,  1.29s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 47%|████▋     | 494/1061 [08:15<10:46,  1.14s/it]

torch.Size([16, 2409, 1025])


 47%|████▋     | 495/1061 [08:17<13:23,  1.42s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 47%|████▋     | 498/1061 [08:20<11:32,  1.23s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 47%|████▋     | 500/1061 [08:21<09:33,  1.02s/it]

torch.Size([16, 2409, 1025])


 47%|████▋     | 501/1061 [08:24<14:15,  1.53s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 48%|████▊     | 504/1061 [08:26<12:32,  1.35s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 48%|████▊     | 506/1061 [08:27<09:33,  1.03s/it]

torch.Size([16, 2409, 1025])


 48%|████▊     | 507/1061 [08:30<14:32,  1.58s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 48%|████▊     | 510/1061 [08:33<12:36,  1.37s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 48%|████▊     | 512/1061 [08:33<10:00,  1.09s/it]

torch.Size([16, 2409, 1025])


 48%|████▊     | 513/1061 [08:38<19:40,  2.15s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 49%|████▊     | 516/1061 [08:40<15:11,  1.67s/it]

torch.Size([16, 2409, 1025])


 49%|████▊     | 517/1061 [08:40<11:22,  1.25s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 49%|████▉     | 519/1061 [08:44<13:31,  1.50s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 49%|████▉     | 522/1061 [08:45<10:36,  1.18s/it]

torch.Size([16, 2409, 1025])


 49%|████▉     | 523/1061 [08:46<08:05,  1.11it/s]

torch.Size([16, 2409, 1025])


 49%|████▉     | 524/1061 [08:46<06:37,  1.35it/s]

torch.Size([16, 2409, 1025])


 49%|████▉     | 525/1061 [08:50<13:56,  1.56s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 50%|████▉     | 528/1061 [08:53<12:48,  1.44s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 50%|████▉     | 530/1061 [08:54<10:43,  1.21s/it]

torch.Size([16, 2409, 1025])


 50%|█████     | 531/1061 [08:56<12:29,  1.41s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 50%|█████     | 534/1061 [09:02<13:55,  1.58s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 51%|█████     | 537/1061 [09:03<10:34,  1.21s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 51%|█████     | 539/1061 [09:04<07:50,  1.11it/s]

torch.Size([16, 2409, 1025])


 51%|█████     | 540/1061 [09:08<17:22,  2.00s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 51%|█████     | 543/1061 [09:10<13:18,  1.54s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 51%|█████▏    | 545/1061 [09:10<09:41,  1.13s/it]

torch.Size([16, 2409, 1025])


 51%|█████▏    | 546/1061 [09:15<19:05,  2.23s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 52%|█████▏    | 548/1061 [09:16<14:20,  1.68s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 52%|█████▏    | 552/1061 [09:21<13:11,  1.55s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 52%|█████▏    | 554/1061 [09:21<10:16,  1.22s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 53%|█████▎    | 558/1061 [09:27<10:39,  1.27s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 53%|█████▎    | 560/1061 [09:28<08:53,  1.06s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 53%|█████▎    | 564/1061 [09:33<09:07,  1.10s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 53%|█████▎    | 566/1061 [09:35<08:31,  1.03s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 54%|█████▎    | 570/1061 [09:39<08:34,  1.05s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 54%|█████▍    | 572/1061 [09:40<07:37,  1.07it/s]

torch.Size([16, 2409, 1025])


 54%|█████▍    | 573/1061 [09:41<06:20,  1.28it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 54%|█████▍    | 576/1061 [09:45<07:24,  1.09it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 54%|█████▍    | 578/1061 [09:46<07:15,  1.11it/s]

torch.Size([16, 2409, 1025])


 55%|█████▍    | 579/1061 [09:47<05:57,  1.35it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 55%|█████▍    | 582/1061 [09:50<07:07,  1.12it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 55%|█████▌    | 584/1061 [09:52<06:30,  1.22it/s]

torch.Size([16, 2409, 1025])


 55%|█████▌    | 585/1061 [09:52<06:27,  1.23it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 55%|█████▌    | 588/1061 [09:56<07:08,  1.10it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 56%|█████▌    | 590/1061 [09:57<06:39,  1.18it/s]

torch.Size([16, 2409, 1025])


 56%|█████▌    | 591/1061 [09:58<06:56,  1.13it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 56%|█████▌    | 594/1061 [10:01<07:14,  1.07it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 56%|█████▌    | 596/1061 [10:03<06:39,  1.16it/s]

torch.Size([16, 2409, 1025])


 56%|█████▋    | 597/1061 [10:04<07:44,  1.00s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 57%|█████▋    | 600/1061 [10:07<07:45,  1.01s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 57%|█████▋    | 602/1061 [10:08<06:42,  1.14it/s]

torch.Size([16, 2409, 1025])


 57%|█████▋    | 603/1061 [10:10<08:33,  1.12s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 57%|█████▋    | 606/1061 [10:13<08:23,  1.11s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 57%|█████▋    | 608/1061 [10:14<07:20,  1.03it/s]

torch.Size([16, 2409, 1025])


 57%|█████▋    | 609/1061 [10:17<10:41,  1.42s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 58%|█████▊    | 612/1061 [10:20<09:25,  1.26s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 58%|█████▊    | 614/1061 [10:21<08:09,  1.10s/it]

torch.Size([16, 2409, 1025])


 58%|█████▊    | 615/1061 [10:23<09:37,  1.29s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 58%|█████▊    | 618/1061 [10:26<09:15,  1.25s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 58%|█████▊    | 620/1061 [10:28<08:17,  1.13s/it]

torch.Size([16, 2409, 1025])


 59%|█████▊    | 621/1061 [10:30<09:38,  1.32s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 59%|█████▉    | 624/1061 [10:33<08:54,  1.22s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 59%|█████▉    | 626/1061 [10:34<08:05,  1.12s/it]

torch.Size([16, 2409, 1025])


 59%|█████▉    | 627/1061 [10:35<07:32,  1.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 59%|█████▉    | 630/1061 [10:38<07:30,  1.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 60%|█████▉    | 632/1061 [10:40<06:57,  1.03it/s]

torch.Size([16, 2409, 1025])


 60%|█████▉    | 633/1061 [10:41<06:40,  1.07it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 60%|█████▉    | 636/1061 [10:44<06:49,  1.04it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 60%|██████    | 638/1061 [10:45<06:05,  1.16it/s]

torch.Size([16, 2409, 1025])


 60%|██████    | 639/1061 [10:47<07:26,  1.06s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 61%|██████    | 642/1061 [10:50<07:31,  1.08s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 61%|██████    | 644/1061 [10:51<06:24,  1.09it/s]

torch.Size([16, 2409, 1025])


 61%|██████    | 645/1061 [10:53<08:01,  1.16s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 61%|██████    | 647/1061 [10:54<06:10,  1.12it/s]

torch.Size([16, 2409, 1025])


 61%|██████    | 648/1061 [10:56<10:03,  1.46s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 61%|██████▏   | 650/1061 [10:57<08:01,  1.17s/it]

torch.Size([16, 2409, 1025])


 61%|██████▏   | 651/1061 [10:59<08:47,  1.29s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 62%|██████▏   | 653/1061 [11:00<07:02,  1.04s/it]

torch.Size([16, 2409, 1025])


 62%|██████▏   | 654/1061 [11:04<13:50,  2.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 62%|██████▏   | 656/1061 [11:04<09:53,  1.46s/it]

torch.Size([16, 2409, 1025])


 62%|██████▏   | 657/1061 [11:06<10:15,  1.52s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 62%|██████▏   | 659/1061 [11:06<07:25,  1.11s/it]

torch.Size([16, 2409, 1025])


 62%|██████▏   | 660/1061 [11:10<11:45,  1.76s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 62%|██████▏   | 662/1061 [11:10<08:31,  1.28s/it]

torch.Size([16, 2409, 1025])


 62%|██████▏   | 663/1061 [11:12<10:17,  1.55s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 63%|██████▎   | 668/1061 [11:16<06:41,  1.02s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 63%|██████▎   | 669/1061 [11:18<09:15,  1.42s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 64%|██████▎   | 674/1061 [11:22<06:16,  1.03it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 64%|██████▎   | 675/1061 [11:24<08:39,  1.35s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 64%|██████▍   | 680/1061 [11:28<05:51,  1.08it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 64%|██████▍   | 681/1061 [11:30<09:20,  1.47s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 64%|██████▍   | 684/1061 [11:34<09:00,  1.43s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 65%|██████▍   | 687/1061 [11:36<07:26,  1.19s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 65%|██████▌   | 690/1061 [11:41<08:17,  1.34s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 66%|██████▌   | 695/1061 [11:42<04:29,  1.36it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 66%|██████▌   | 696/1061 [11:48<12:52,  2.12s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 66%|██████▌   | 699/1061 [11:48<09:07,  1.51s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 66%|██████▌   | 701/1061 [11:48<06:32,  1.09s/it]

torch.Size([16, 2409, 1025])


 66%|██████▌   | 702/1061 [11:54<14:59,  2.51s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 67%|██████▋   | 707/1061 [11:54<10:28,  1.78s/it]

torch.Size([16, 2409, 1025])


 67%|██████▋   | 708/1061 [12:01<18:49,  3.20s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 67%|██████▋   | 713/1061 [12:02<13:18,  2.29s/it]

torch.Size([16, 2409, 1025])


 67%|██████▋   | 714/1061 [12:08<19:27,  3.36s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 68%|██████▊   | 719/1061 [12:09<13:44,  2.41s/it]

torch.Size([16, 2409, 1025])


 68%|██████▊   | 720/1061 [12:13<17:46,  3.13s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 68%|██████▊   | 725/1061 [12:14<12:33,  2.24s/it]

torch.Size([16, 2409, 1025])


 68%|██████▊   | 726/1061 [12:20<18:05,  3.24s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 69%|██████▉   | 731/1061 [12:21<12:46,  2.32s/it]

torch.Size([16, 2409, 1025])


 69%|██████▉   | 732/1061 [12:27<18:40,  3.41s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2808, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 69%|██████▉   | 737/1061 [12:28<13:12,  2.45s/it]

torch.Size([16, 2409, 1025])


 70%|██████▉   | 738/1061 [12:33<18:08,  3.37s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 70%|███████   | 743/1061 [12:35<12:59,  2.45s/it]

torch.Size([16, 2409, 1025])


 70%|███████   | 744/1061 [12:39<16:18,  3.09s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 71%|███████   | 749/1061 [12:40<11:34,  2.22s/it]

torch.Size([16, 2409, 1025])


 71%|███████   | 753/1061 [12:45<10:18,  2.01s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 71%|███████   | 755/1061 [12:46<08:18,  1.63s/it]

torch.Size([16, 2409, 1025])


 71%|███████▏  | 756/1061 [12:51<12:46,  2.51s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 72%|███████▏  | 761/1061 [12:52<09:05,  1.82s/it]

torch.Size([16, 2409, 1025])


 72%|███████▏  | 762/1061 [12:57<13:55,  2.79s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 72%|███████▏  | 767/1061 [12:58<10:01,  2.05s/it]

torch.Size([16, 2409, 1025])


 72%|███████▏  | 768/1061 [13:02<12:50,  2.63s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 73%|███████▎  | 771/1061 [13:03<09:08,  1.89s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 73%|███████▎  | 773/1061 [13:04<07:18,  1.52s/it]

torch.Size([16, 2409, 1025])


 73%|███████▎  | 774/1061 [13:08<11:07,  2.33s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 73%|███████▎  | 777/1061 [13:09<08:05,  1.71s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 73%|███████▎  | 779/1061 [13:10<05:58,  1.27s/it]

torch.Size([16, 2409, 1025])


 74%|███████▎  | 780/1061 [13:14<09:49,  2.10s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 74%|███████▍  | 783/1061 [13:15<07:18,  1.58s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 74%|███████▍  | 785/1061 [13:16<05:41,  1.24s/it]

torch.Size([16, 2409, 1025])


 74%|███████▍  | 786/1061 [13:20<09:33,  2.09s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 74%|███████▍  | 789/1061 [13:21<07:15,  1.60s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 75%|███████▍  | 791/1061 [13:22<05:17,  1.18s/it]

torch.Size([16, 2409, 1025])


 75%|███████▍  | 792/1061 [13:26<09:03,  2.02s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 75%|███████▍  | 795/1061 [13:27<06:56,  1.57s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 75%|███████▌  | 797/1061 [13:28<05:05,  1.16s/it]

torch.Size([16, 2409, 1025])


 75%|███████▌  | 798/1061 [13:32<09:04,  2.07s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 75%|███████▌  | 801/1061 [13:33<06:50,  1.58s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 76%|███████▌  | 803/1061 [13:34<05:06,  1.19s/it]

torch.Size([16, 2409, 1025])


 76%|███████▌  | 804/1061 [13:38<09:25,  2.20s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 76%|███████▌  | 807/1061 [13:40<07:07,  1.68s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 76%|███████▋  | 810/1061 [13:45<07:01,  1.68s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 77%|███████▋  | 813/1061 [13:45<05:09,  1.25s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 77%|███████▋  | 816/1061 [13:51<05:45,  1.41s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 77%|███████▋  | 819/1061 [13:51<04:08,  1.03s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 77%|███████▋  | 822/1061 [13:56<05:00,  1.26s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 78%|███████▊  | 825/1061 [13:58<03:54,  1.01it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 78%|███████▊  | 828/1061 [14:02<04:23,  1.13s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 78%|███████▊  | 831/1061 [14:04<03:39,  1.05it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 79%|███████▊  | 834/1061 [14:08<04:00,  1.06s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 79%|███████▉  | 837/1061 [14:09<03:29,  1.07it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 79%|███████▉  | 840/1061 [14:13<03:48,  1.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 79%|███████▉  | 843/1061 [14:15<03:23,  1.07it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 80%|███████▉  | 846/1061 [14:19<03:43,  1.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 80%|████████  | 849/1061 [14:21<03:11,  1.11it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 80%|████████  | 852/1061 [14:25<03:37,  1.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 81%|████████  | 855/1061 [14:27<03:06,  1.11it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 81%|████████  | 858/1061 [14:31<03:31,  1.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 81%|████████  | 861/1061 [14:32<02:50,  1.17it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 81%|████████▏ | 864/1061 [14:37<03:37,  1.10s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 82%|████████▏ | 867/1061 [14:38<02:54,  1.11it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 82%|████████▏ | 870/1061 [14:44<03:41,  1.16s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 82%|████████▏ | 873/1061 [14:45<02:51,  1.10it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 83%|████████▎ | 876/1061 [14:49<03:21,  1.09s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 83%|████████▎ | 879/1061 [14:52<02:59,  1.01it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 83%|████████▎ | 882/1061 [14:55<03:11,  1.07s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 83%|████████▎ | 885/1061 [14:58<02:59,  1.02s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 84%|████████▎ | 888/1061 [15:02<03:13,  1.12s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 84%|████████▍ | 891/1061 [15:04<02:38,  1.07it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 84%|████████▍ | 894/1061 [15:08<03:07,  1.12s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 85%|████████▍ | 897/1061 [15:09<02:21,  1.16it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 85%|████████▍ | 900/1061 [15:14<02:54,  1.08s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 85%|████████▌ | 903/1061 [15:15<02:16,  1.16it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 85%|████████▌ | 906/1061 [15:19<02:45,  1.07s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 86%|████████▌ | 909/1061 [15:21<02:11,  1.16it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 86%|████████▌ | 912/1061 [15:25<02:36,  1.05s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 86%|████████▌ | 915/1061 [15:26<02:03,  1.18it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 87%|████████▋ | 918/1061 [15:31<02:33,  1.07s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 87%|████████▋ | 921/1061 [15:32<01:56,  1.20it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 87%|████████▋ | 924/1061 [15:37<02:23,  1.05s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 87%|████████▋ | 927/1061 [15:38<02:00,  1.11it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 88%|████████▊ | 930/1061 [15:43<02:22,  1.09s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 88%|████████▊ | 933/1061 [15:44<01:52,  1.13it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 88%|████████▊ | 936/1061 [15:49<02:18,  1.11s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 89%|████████▊ | 939/1061 [15:50<01:50,  1.11it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 89%|████████▉ | 942/1061 [15:55<02:09,  1.09s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 89%|████████▉ | 945/1061 [15:56<01:42,  1.13it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 89%|████████▉ | 948/1061 [16:00<01:59,  1.06s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 90%|████████▉ | 951/1061 [16:01<01:33,  1.17it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 90%|████████▉ | 954/1061 [16:06<01:51,  1.04s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 90%|█████████ | 957/1061 [16:07<01:27,  1.19it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 90%|█████████ | 960/1061 [16:12<01:46,  1.06s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 91%|█████████ | 963/1061 [16:13<01:24,  1.16it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 91%|█████████ | 966/1061 [16:18<01:43,  1.09s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 91%|█████████▏| 969/1061 [16:19<01:21,  1.13it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 92%|█████████▏| 972/1061 [16:24<01:38,  1.11s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 92%|█████████▏| 975/1061 [16:25<01:14,  1.15it/s]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 92%|█████████▏| 977/1061 [16:25<00:56,  1.49it/s]

torch.Size([16, 2409, 1025])


 92%|█████████▏| 978/1061 [16:30<02:39,  1.92s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 93%|█████████▎| 983/1061 [16:31<01:20,  1.03s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 93%|█████████▎| 984/1061 [16:36<02:46,  2.16s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 93%|█████████▎| 987/1061 [16:37<01:58,  1.60s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 93%|█████████▎| 989/1061 [16:37<01:25,  1.19s/it]

torch.Size([16, 2409, 1025])


 93%|█████████▎| 990/1061 [16:42<02:48,  2.38s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 94%|█████████▎| 993/1061 [16:43<01:58,  1.74s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 94%|█████████▍| 995/1061 [16:44<01:27,  1.33s/it]

torch.Size([16, 2409, 1025])


 94%|█████████▍| 996/1061 [16:49<02:32,  2.35s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 94%|█████████▍| 999/1061 [16:49<01:43,  1.67s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 94%|█████████▍| 1001/1061 [16:50<01:21,  1.35s/it]

torch.Size([16, 2409, 1025])


 94%|█████████▍| 1002/1061 [16:55<02:14,  2.28s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 95%|█████████▍| 1005/1061 [16:55<01:32,  1.65s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 95%|█████████▍| 1007/1061 [16:56<01:08,  1.27s/it]

torch.Size([16, 2409, 1025])


 95%|█████████▌| 1008/1061 [17:01<02:00,  2.28s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 95%|█████████▌| 1013/1061 [17:02<01:21,  1.69s/it]

torch.Size([16, 2409, 1025])


 96%|█████████▌| 1014/1061 [17:06<01:52,  2.39s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 96%|█████████▌| 1017/1061 [17:07<01:15,  1.72s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 96%|█████████▌| 1019/1061 [17:08<00:59,  1.42s/it]

torch.Size([16, 2409, 1025])


 96%|█████████▌| 1020/1061 [17:13<01:38,  2.40s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 97%|█████████▋| 1025/1061 [17:14<01:01,  1.72s/it]

torch.Size([16, 2409, 1025])


 97%|█████████▋| 1026/1061 [17:19<01:39,  2.85s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 97%|█████████▋| 1031/1061 [17:19<01:00,  2.01s/it]

torch.Size([16, 2409, 1025])


 98%|█████████▊| 1037/1061 [17:25<00:52,  2.18s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 98%|█████████▊| 1043/1061 [17:31<00:30,  1.67s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 98%|█████████▊| 1045/1061 [17:36<00:31,  1.98s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


 99%|█████████▉| 1049/1061 [17:37<00:17,  1.42s/it]

torch.Size([16, 2409, 1025])


 99%|█████████▉| 1051/1061 [17:43<00:19,  1.92s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])


100%|██████████| 1061/1061 [17:48<00:00,  1.01s/it]

torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])
torch.Size([16, 2409, 1025])





labels

label_lengths

In [12]:
def train(model, device, train_loader, criterion, epoch, writer, optimizer=None):
    """Run one training epoch and write a checkpoint every other epoch.

    Args:
        model: module whose forward returns a tuple ``(output, _)``. ``output`` is
            assumed to be ``(batch, time, n_class)`` (as documented in ``test``);
            TODO confirm against the model-definition cell.
        device: torch device the batches are moved to.
        train_loader: DataLoader yielding
            ``(inputs, labels, input_lengths, label_lengths)`` batches.
        criterion: loss called as
            ``criterion(log_probs, labels, input_lengths, label_lengths)`` with
            ``log_probs`` shaped ``(time, batch, n_class)`` (CTC-style).
        epoch: epoch index. Checkpoint cadence and naming use ``epoch + 1``,
            i.e. they treat ``epoch`` as 0-based.
        writer: TensorBoard ``SummaryWriter``; receives one ``Loss`` scalar per batch.
        optimizer: optimizer to step. Defaults to the notebook-level ``optimizer``
            global so existing calls without this argument keep working.
    """
    if optimizer is None:
        # Backward compatibility: an earlier notebook cell defines `optimizer` globally.
        optimizer = globals()["optimizer"]

    model.train()
    data_len = len(train_loader.dataset)
    num_batches = len(train_loader)
    pbar = tqdm(enumerate(train_loader), total=num_batches)
    for batch_idx, (wav, labels, input_lengths, label_lengths) in pbar:
        wav = wav.to(device).float()
        labels = labels.to(device)
        optimizer.zero_grad()

        output, _ = model(wav)
        # Normalise over the class axis (last dim), not time: CTC expects
        # per-frame log-probabilities over the vocabulary.
        output = F.log_softmax(output, dim=-1)
        output = output.transpose(0, 1)  # -> (time, batch, n_class)

        # NOTE(review): input_lengths are spectrogram frame counts; this assumes
        # the model keeps the time dimension unchanged — TODO confirm.
        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        # One distinct global step per batch (previously every batch of an epoch
        # was logged at the same step).
        writer.add_scalar('Loss', loss_value, epoch * num_batches + batch_idx)
        if batch_idx % 100 == 0 or batch_idx == num_batches - 1:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(wav), data_len,
                    100. * batch_idx / num_batches, loss_value))

    if (epoch + 1) % 2 == 0:
        model.eval().cpu()
        os.makedirs("checkpoints", exist_ok=True)
        # num_batches equals the original `batch_idx + 1` after a full epoch,
        # but is also defined when the loader is empty.
        ckpt_model_filename = "ckpt_epoch_" + str(epoch + 1) + "_batch_id_" + str(num_batches) + ".pth"
        ckpt_model_path = os.path.join("checkpoints/", ckpt_model_filename)
        torch.save(model.state_dict(), ckpt_model_path)
        model.to(device).train()

In [13]:
def test(model, device, test_loader, criterion, epoch, writer=None):
    """Evaluate the model on ``test_loader``; print and return loss and a mismatch rate.

    Args:
        model: network whose forward returns ``(logits, _)`` with logits laid
            out as ``(batch, time, n_class)``.
        device: device each batch is moved to.
        test_loader: DataLoader yielding
            ``(inputs, labels, input_lengths, label_lengths)``.
        criterion: CTC-style loss taking ``(time, batch, n_class)`` log-probabilities.
        epoch: epoch index; used as the TensorBoard step when ``writer`` is given.
        writer: optional ``SummaryWriter``; accepted so the training cell's
            ``test(..., epoch, writer)`` call works.

    Returns:
        ``(test_loss, mean_eer)``: the loss averaged over batches, and the mean of
        the per-batch mismatch rates. Both are 0.0 for an empty loader.
    """
    model.eval()
    test_loss, total_eer = 0.0, 0.0
    num_batches = len(test_loader)
    with torch.no_grad():
        for batch_idx, (inputs, labels, input_lengths, label_lengths) in enumerate(test_loader):
            inputs = inputs.to(device).float()  # same dtype handling as train()
            labels = labels.to(device)

            output, _ = model(inputs)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=-1)  # normalize over classes
            output = output.transpose(0, 1)  # (time, batch, n_class)
            loss = criterion(output, labels, input_lengths, label_lengths)
            test_loss += loss.item() / num_batches

            # NOTE(review): the GreedyDecoder import is commented out at the top
            # of the notebook; restore it or this raises NameError.
            decoded_preds, decoded_targets = GreedyDecoder(output.transpose(0, 1), labels, label_lengths)
            # A sample counts as wrong when prediction and target contain a
            # different number of distinct symbols. This is a coarse proxy, not
            # a true equal-error rate.
            frr_far = sum(
                len(set(pred)) != len(set(target))
                for pred, target in zip(decoded_preds, decoded_targets)
            )
            eer = frr_far / len(decoded_preds) if decoded_preds else 0.0
            total_eer += eer
            print("EER: ", eer)
            print("Total EER: ", total_eer)

    mean_eer = total_eer / num_batches if num_batches else 0.0
    if writer is not None:
        writer.add_scalar('TestLoss', test_loss, epoch)
        writer.add_scalar('EER', mean_eer, epoch)
    return test_loss, mean_eer

In [14]:
# Cast BatchNorm2d layers to float32 (presumably to keep them in full precision
# if the rest of the model uses another dtype — TODO confirm why this is needed).
for layer in model.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.float()

writer = SummaryWriter('train_logs/')
try:
    for epoch in range(1, epochs+1):
        train(model, device, train_loader, criterion, epoch, writer)
        #test(model, device, test_loader, criterion, epoch, writer)
finally:
    writer.close()  # flush pending TensorBoard events even if training fails

model.eval().cpu()
os.makedirs("checkpoints", exist_ok=True)
# `batch_idx` is local to train(), so it is not defined here. After a full epoch,
# train() ends with batch_idx + 1 == len(train_loader). `epochs + 1` equals the
# old `epoch + 1` after the loop and also works when epochs == 0.
save_model_filename = "final_epoch_" + str(epochs + 1) + "_batch_id_" + str(len(train_loader)) + ".model"
save_model_path = os.path.join("checkpoints", save_model_filename)
torch.save(model.state_dict(), save_model_path)

print("\nDone, trained model saved at", save_model_path)

  0%|          | 1/849 [00:07<1:51:27,  7.89s/it]



 12%|█▏        | 101/849 [01:52<10:15,  1.21it/s]



 24%|██▎       | 201/849 [03:38<11:56,  1.11s/it]



 35%|███▌      | 301/849 [05:23<09:49,  1.08s/it]



 37%|███▋      | 316/849 [05:39<08:30,  1.04it/s]

RuntimeError: CUDA out of memory. Tried to allocate 1.97 GiB (GPU 0; 5.79 GiB total capacity; 1.27 GiB already allocated; 1.54 GiB free; 2.63 GiB reserved in total by PyTorch)

In [None]:
# Length of `trans`, which is bound in kernel state outside the visible cells
# (TODO confirm where it is assigned before relying on this cell).
len(trans)

In [None]:
# First element of `trans` (a single character if `trans` is a string).
trans[0]

In [None]:
# Shape of the first element of the current `wav` global.
# NOTE(review): a later cell rebinds `wav` via librosa.load; which `wav` this
# inspects depends on execution order — confirm.
wav[0].shape

In [None]:
# Alias the current globals as a "training set" for the inspection cells below.
# NOTE(review): `label` is not defined in any visible cell (train() uses
# `labels`); this relies on hidden kernel state and fails on a fresh kernel.
y_train = label
x_train = wav

In [None]:
# Sample count, and the length of the first row of the first sample
# (presumably the per-frame feature size if samples are (frames, features) — TODO confirm).
nb_train = len(x_train)
nb_features = len(x_train[0][0])

In [None]:
# Display the feature size computed above.
nb_features

In [None]:
# Display the sample count computed above.
nb_train

In [None]:
# Length (first axis) of each of the first nb_train samples.
x_train_len = np.asarray([len(x_train[i]) for i in range(nb_train)])
print(x_train_len)

In [None]:
# CTC sanity check. Compare the loss of two one-hot predictions for two input
# lengths: a "perfect" one (the target sequence followed by blanks) and an
# all-blank one.
T = 189              # number of time frames
n_class = 80         # vocabulary size including the blank
BLANK = n_class - 1  # blank index; previously hard-coded as 79 in one place
y = torch.tensor([[55, 43, 40, 62, 41, 44, 53, 40, 62, 41, 50, 53, 62, 58, 36, 54, 43, 44, 49, 42]])
# 1-D per-batch length tensor (batch size 1), as ctc_loss expects for batched input.
output_length = torch.tensor([y.shape[1]])

pred_model_idx = torch.full((T,), BLANK, dtype=torch.long)  # blank at every frame
# The first y.shape[1] frames match y exactly; the remaining frames are blank.
pred_perf_idx = torch.cat([y[0], torch.full((T - y.shape[1],), BLANK, dtype=torch.long)])
pred_model = F.one_hot(pred_model_idx, n_class).float().unsqueeze(1)  # (T, 1, n_class)
pred_perf = F.one_hot(pred_perf_idx, n_class).float().unsqueeze(1)    # (T, 1, n_class)

for input_length in [torch.tensor([y.shape[1]]), torch.tensor([T])]:
    print("=============\ninput length:", input_length)
    print("perfect loss:", F.ctc_loss(F.log_softmax(pred_perf, dim=2), y, input_length, output_length,
                                      blank=BLANK, reduction='none', zero_infinity=True))
    print("all_blank loss:", F.ctc_loss(F.log_softmax(pred_model, dim=2), y, input_length, output_length,
                                        blank=BLANK, reduction='none', zero_infinity=True))


In [None]:
# Load one sample clip and the transcription table for quick inspection. Both
# paths are built from the same language, so audio and transcripts match (the
# audio used to be Telugu while the transcripts were Gujarati).
SAMPLE_LANG = "Telugu"
SAMPLE_DIR = os.path.join("Data", "PartA_" + SAMPLE_LANG, "Train")
# NOTE(review): librosa.load resamples to its default rate unless sr=None is
# passed; confirm this matches what the dataset loader uses.
# NOTE(review): this rebinds the global `wav` that the inspection cells above read.
wav, sr = librosa.load(os.path.join(SAMPLE_DIR, "Audio", "000010010.wav"))
df = pd.read_csv(os.path.join(SAMPLE_DIR, "Transcription_LT_Sequence.tsv"), header=None, sep='\t')

In [None]:
# Shape of the one-hot all-blank prediction from the CTC sanity-check cell: (T, 1, n_class).
pred_model.shape

In [None]:
# Display the one-hot "perfect" prediction tensor.
pred_perf

In [None]:
# Random (2, 2, 2) tensor for checking nn.Linear on a 3-D input.
# Unseeded, so the values differ on every run.
qw = torch.rand((2,2,2))

In [None]:
# Display the random tensor.
qw

In [None]:
# nn.Linear maps the last dimension (2 -> 29), so this should be (2, 2, 29).
nn.Linear(2, 29)(qw).shape