In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
from src.dataset2 import MIR1K2
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchaudio.models import HDemucs
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from collections import defaultdict
from src.constants import SAMPLE_RATE
from src.loss import bce
from src.utils import to_local_average_cents
from mir_eval.melody import (
    raw_pitch_accuracy,
    to_cent_voicing,
    raw_chroma_accuracy,
    overall_accuracy,
)

In [2]:
HOP_LENGTH = 20
SEQ_L = 2.56
BATCH_SIZE = 8
SAMPLE_RATE = 16000

train_dataset = MIR1K2("dataset", HOP_LENGTH, SEQ_L, groups= ["train"])
validation_dataset = MIR1K2('dataset', HOP_LENGTH, SEQ_L, groups= ["test"])

Loading 1 group of MIR1K2 at dataset


Loading group train: 100%|██████████| 800/800 [00:07<00:00, 101.93it/s]


Loading 1 group of MIR1K2 at dataset


Loading group test: 100%|██████████| 200/200 [00:01<00:00, 109.11it/s]


In [3]:
# import dataloader
data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [4]:
def cycle(iterable):
    while True:
        for x in iterable:
            yield x

In [5]:
def conv_layer(input_channels, output_channels, kernel_size, stride):
    return nn.Sequential(
        nn.Conv1d(input_channels, output_channels, kernel_size, stride),
        nn.BatchNorm1d(output_channels),
        nn.ReLU(),
        nn.MaxPool1d(2, 2),
    )


class E2E_Demucs(nn.Module):
    def __init__(self):
        super(E2E_Demucs, self).__init__()
        self.hdemucs = HDemucs(sources=["vocals"], nfft=2048, depth=6, audio_channels=1)
        self.bigru = nn.GRU(1024, 256, 1, batch_first=True, bidirectional=True).cuda()
        self.conv = nn.Sequential(
            conv_layer(1, 512, 90, 1),
            conv_layer(512, 64, 12, 1),
            conv_layer(64, 128, 12, 1),
            conv_layer(128, 256, 12, 1),
        )
        self.fc = nn.Sequential(nn.Linear(512, 360), nn.Dropout(0.25), nn.Sigmoid())

    def forward(self, audio):
        x = self.hdemucs(audio)
        x = x.squeeze(1).view(x.size(0), -1, 320)
        B, S, T = x.shape
        x = self.conv(x.view(B * S, 1, T)).view(B, S, -1)
        x, _ = self.bigru(x)
        x = self.fc(x)
        return x

In [6]:
def evaluate2(dataset, model, hop_length, device, pitch_th=0.0):
    metrics = defaultdict(list)
    for data in dataset:
        audio = data['audio'].view(1,1,-1).to(device)
        pitch_label = data["pitch"].to(device)
        pitch_pred = model(audio).squeeze()

        loss = bce(pitch_pred, pitch_label)
        metrics["loss"].append(loss.item())

        cents_pred = to_local_average_cents(pitch_pred.cpu().numpy(), None, pitch_th)

        cents_label = to_local_average_cents(pitch_label.cpu().numpy(), None, pitch_th)

        freq_pred = np.array(
            [
                10 * (2 ** (cent_pred / 1200)) if cent_pred else 0
                for cent_pred in cents_pred
            ]
        )
        freq = np.array(
            [10 * (2 ** (cent / 1200)) if cent else 0 for cent in cents_label]
        )

        time_slice = np.array([i * hop_length / 1000 for i in range(len(cents_label))])
        ref_v, ref_c, est_v, est_c = to_cent_voicing(
            time_slice, freq, time_slice, freq_pred
        )

        rpa = raw_pitch_accuracy(ref_v, ref_c, est_v, est_c)
        rca = raw_chroma_accuracy(ref_v, ref_c, est_v, est_c)
        oa = overall_accuracy(ref_v, ref_c, est_v, est_c)
        metrics["RPA"].append(rpa)
        metrics["RCA"].append(rca)
        metrics["OA"].append(oa)
        # if rpa < 0.9:
        print(data["file"], ":\t", rpa, "\t", oa)

    return metrics


In [8]:
from src.loss import FL


log_dir = 'runs/E2E2'

learning_rate = 5e-4

validation_interval = len(data_loader) 
iterations = len(data_loader) * 100
learning_rate_decay_steps = len(data_loader) * 5
learning_rate_decay_rate = 0.98

os.makedirs(log_dir, exist_ok=True)
writer= SummaryWriter(log_dir)

model = E2E_Demucs().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)

loop = tqdm(range(1, iterations + 1))
RPA, RCA, OA, VFA, VR, it = 0, 0, 0, 0, 0, 0

for i, data in zip(loop, cycle(data_loader)):
    audio = data["audio"].unsqueeze(1).cuda()
    pitch_label = data["pitch"].cuda()
    pitch_pred = model(audio)
    
    loss = FL(pitch_pred, pitch_label, 6, 0)

    print(i, end="\t")
    print("loss_total:", loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    writer.add_scalar("loss/loss_pitch", loss.item(), global_step=i)

    if i % validation_interval == 0:
        model.eval()
        with torch.no_grad():
            metrics = evaluate2(validation_dataset, model, HOP_LENGTH, 'cuda')
            for key, value in metrics.items():
                writer.add_scalar(
                    "stage_pitch/" + key, np.mean(value), global_step=i
                )
            rpa = np.mean(metrics["RPA"])
            rca = np.mean(metrics["RCA"])
            oa = np.mean(metrics["OA"])
            if rpa >= RPA:
                RPA, RCA, OA, it = rpa, rca, oa, i
                with open(os.path.join(log_dir, "result.txt"), "a") as f:
                    f.write(str(i) + "\t")
                    f.write(str(RPA) + "\t")
                    f.write(str(RCA) + "\t")
                    f.write(str(OA) + "\t")
                torch.save(model, os.path.join(log_dir, f"model.pt"))
        model.train()

    if i - it > len(data_loader) * 10:
        break



1	loss_total: 0.7013555765151978


  0%|          | 2/36400 [00:01<4:53:47,  2.06it/s]

2	loss_total: 0.6951654553413391


  0%|          | 3/36400 [00:01<3:41:49,  2.73it/s]

3	loss_total: 0.6892274618148804


  0%|          | 4/36400 [00:01<3:09:04,  3.21it/s]

4	loss_total: 0.6785494089126587


  0%|          | 5/36400 [00:01<2:50:19,  3.56it/s]

5	loss_total: 0.6543348431587219


  0%|          | 6/36400 [00:01<2:39:13,  3.81it/s]

6	loss_total: 0.5935062766075134


  0%|          | 7/36400 [00:02<2:31:47,  4.00it/s]

7	loss_total: 0.49129757285118103


  0%|          | 8/36400 [00:02<2:27:25,  4.11it/s]

8	loss_total: 0.4085785150527954


  0%|          | 9/36400 [00:02<2:24:24,  4.20it/s]

9	loss_total: 0.3307751715183258


  0%|          | 10/36400 [00:02<2:22:41,  4.25it/s]

10	loss_total: 0.2902195155620575


  0%|          | 11/36400 [00:03<2:24:41,  4.19it/s]

11	loss_total: 0.26271405816078186


  0%|          | 12/36400 [00:03<2:22:10,  4.27it/s]

12	loss_total: 0.2472187578678131


  0%|          | 13/36400 [00:03<2:21:04,  4.30it/s]

13	loss_total: 0.24329152703285217


  0%|          | 14/36400 [00:03<2:20:32,  4.32it/s]

14	loss_total: 0.23170991241931915


  0%|          | 15/36400 [00:04<2:19:23,  4.35it/s]

15	loss_total: 0.23389272391796112


  0%|          | 16/36400 [00:04<2:18:19,  4.38it/s]

16	loss_total: 0.22889964282512665


  0%|          | 17/36400 [00:04<2:18:34,  4.38it/s]

17	loss_total: 0.22754669189453125


  0%|          | 18/36400 [00:04<2:17:35,  4.41it/s]

18	loss_total: 0.2206094115972519


  0%|          | 19/36400 [00:04<2:17:36,  4.41it/s]

19	loss_total: 0.2205173820257187


  0%|          | 20/36400 [00:05<2:17:25,  4.41it/s]

20	loss_total: 0.2209448367357254


  0%|          | 21/36400 [00:05<2:17:13,  4.42it/s]

21	loss_total: 0.22096893191337585


  0%|          | 22/36400 [00:05<2:25:53,  4.16it/s]

22	loss_total: 0.2207331359386444


  0%|          | 23/36400 [00:05<2:23:25,  4.23it/s]

23	loss_total: 0.21959078311920166


  0%|          | 24/36400 [00:06<2:20:57,  4.30it/s]

24	loss_total: 0.22147010266780853


  0%|          | 25/36400 [00:06<2:19:46,  4.34it/s]

25	loss_total: 0.21507060527801514


  0%|          | 26/36400 [00:06<2:18:51,  4.37it/s]

26	loss_total: 0.20913442969322205


  0%|          | 27/36400 [00:06<2:17:23,  4.41it/s]

27	loss_total: 0.21200519800186157


  0%|          | 28/36400 [00:07<2:16:49,  4.43it/s]

28	loss_total: 0.2175069898366928


  0%|          | 29/36400 [00:07<2:16:20,  4.45it/s]

29	loss_total: 0.20978011190891266


  0%|          | 30/36400 [00:07<2:15:48,  4.46it/s]

30	loss_total: 0.21844947338104248


  0%|          | 31/36400 [00:07<2:15:36,  4.47it/s]

31	loss_total: 0.20991675555706024


  0%|          | 32/36400 [00:07<2:16:04,  4.45it/s]

32	loss_total: 0.21178831160068512


  0%|          | 33/36400 [00:08<2:24:35,  4.19it/s]

33	loss_total: 0.2149837762117386


  0%|          | 34/36400 [00:08<2:21:51,  4.27it/s]

34	loss_total: 0.20883840322494507


  0%|          | 35/36400 [00:08<2:20:01,  4.33it/s]

35	loss_total: 0.20488111674785614


  0%|          | 36/36400 [00:08<2:18:39,  4.37it/s]

36	loss_total: 0.21999970078468323


  0%|          | 37/36400 [00:09<2:17:40,  4.40it/s]

37	loss_total: 0.21030880510807037


  0%|          | 38/36400 [00:09<2:17:27,  4.41it/s]

38	loss_total: 0.21991612017154694


  0%|          | 39/36400 [00:09<2:16:29,  4.44it/s]

39	loss_total: 0.20716409385204315


  0%|          | 40/36400 [00:09<2:15:53,  4.46it/s]

40	loss_total: 0.21482816338539124


  0%|          | 41/36400 [00:09<2:15:50,  4.46it/s]

41	loss_total: 0.21360991895198822


  0%|          | 42/36400 [00:10<2:16:36,  4.44it/s]

42	loss_total: 0.2112683206796646


  0%|          | 43/36400 [00:10<2:16:37,  4.43it/s]

43	loss_total: 0.21501031517982483


  0%|          | 44/36400 [00:10<2:17:15,  4.41it/s]

44	loss_total: 0.2171206772327423


  0%|          | 45/36400 [00:10<2:17:20,  4.41it/s]

45	loss_total: 0.2089376300573349


  0%|          | 46/36400 [00:11<2:16:31,  4.44it/s]

46	loss_total: 0.20949317514896393


  0%|          | 47/36400 [00:11<2:16:31,  4.44it/s]

47	loss_total: 0.21379660069942474


  0%|          | 48/36400 [00:11<2:17:33,  4.40it/s]

48	loss_total: 0.208268940448761


  0%|          | 49/36400 [00:11<2:16:38,  4.43it/s]

49	loss_total: 0.20543137192726135


  0%|          | 50/36400 [00:12<2:17:59,  4.39it/s]

50	loss_total: 0.20982670783996582


  0%|          | 51/36400 [00:12<2:18:23,  4.38it/s]

51	loss_total: 0.2067224383354187


  0%|          | 52/36400 [00:12<2:17:17,  4.41it/s]

52	loss_total: 0.21003887057304382


  0%|          | 53/36400 [00:12<2:16:59,  4.42it/s]

53	loss_total: 0.2132163643836975


  0%|          | 54/36400 [00:12<2:16:36,  4.43it/s]

54	loss_total: 0.2111053764820099


  0%|          | 55/36400 [00:13<2:15:46,  4.46it/s]

55	loss_total: 0.21204796433448792


  0%|          | 56/36400 [00:13<2:16:03,  4.45it/s]

56	loss_total: 0.2130892127752304


  0%|          | 57/36400 [00:13<2:16:23,  4.44it/s]

57	loss_total: 0.21794137358665466


  0%|          | 58/36400 [00:13<2:15:59,  4.45it/s]

58	loss_total: 0.21413911879062653


  0%|          | 59/36400 [00:14<2:17:31,  4.40it/s]

59	loss_total: 0.2195499837398529


  0%|          | 60/36400 [00:14<2:17:30,  4.40it/s]

60	loss_total: 0.217513769865036


  0%|          | 61/36400 [00:14<2:29:21,  4.06it/s]

61	loss_total: 0.21424677968025208


  0%|          | 62/36400 [00:14<2:26:14,  4.14it/s]

62	loss_total: 0.20913365483283997


  0%|          | 63/36400 [00:15<2:22:52,  4.24it/s]

63	loss_total: 0.2116137593984604


  0%|          | 64/36400 [00:15<2:21:26,  4.28it/s]

64	loss_total: 0.21227112412452698


  0%|          | 65/36400 [00:15<2:20:32,  4.31it/s]

65	loss_total: 0.21278351545333862


  0%|          | 66/36400 [00:15<2:19:33,  4.34it/s]

66	loss_total: 0.21248012781143188


  0%|          | 67/36400 [00:15<2:18:11,  4.38it/s]

67	loss_total: 0.21248102188110352


  0%|          | 68/36400 [00:16<2:17:36,  4.40it/s]

68	loss_total: 0.21368908882141113


  0%|          | 69/36400 [00:16<2:17:44,  4.40it/s]

69	loss_total: 0.21260716021060944


  0%|          | 70/36400 [00:16<2:17:28,  4.40it/s]

70	loss_total: 0.21163427829742432


  0%|          | 71/36400 [00:16<2:22:56,  4.24it/s]

71	loss_total: 0.2129851132631302


  0%|          | 72/36400 [00:17<2:30:22,  4.03it/s]

72	loss_total: 0.21390707790851593


  0%|          | 73/36400 [00:17<2:25:34,  4.16it/s]

73	loss_total: 0.21084874868392944


  0%|          | 74/36400 [00:17<2:22:57,  4.24it/s]

74	loss_total: 0.21144568920135498


  0%|          | 75/36400 [00:17<2:20:54,  4.30it/s]

75	loss_total: 0.21193930506706238


  0%|          | 76/36400 [00:18<2:20:24,  4.31it/s]

76	loss_total: 0.21007156372070312


  0%|          | 77/36400 [00:18<2:18:36,  4.37it/s]

77	loss_total: 0.20308202505111694


  0%|          | 78/36400 [00:18<2:17:31,  4.40it/s]

78	loss_total: 0.21521425247192383


  0%|          | 79/36400 [00:18<2:16:56,  4.42it/s]

79	loss_total: 0.21478265523910522


  0%|          | 80/36400 [00:18<2:16:34,  4.43it/s]

80	loss_total: 0.21424749493598938


  0%|          | 81/36400 [00:19<2:15:59,  4.45it/s]

81	loss_total: 0.2083844691514969


  0%|          | 82/36400 [00:19<2:20:13,  4.32it/s]

82	loss_total: 0.21319444477558136


  0%|          | 83/36400 [00:19<2:18:38,  4.37it/s]

83	loss_total: 0.20820289850234985


  0%|          | 84/36400 [00:19<2:18:11,  4.38it/s]

84	loss_total: 0.2138470560312271


  0%|          | 85/36400 [00:20<2:17:24,  4.40it/s]

85	loss_total: 0.2133149951696396


  0%|          | 86/36400 [00:20<2:22:18,  4.25it/s]

86	loss_total: 0.21999728679656982


  0%|          | 87/36400 [00:20<2:23:00,  4.23it/s]

87	loss_total: 0.2132280170917511


  0%|          | 88/36400 [00:20<2:20:57,  4.29it/s]

88	loss_total: 0.21476604044437408


  0%|          | 89/36400 [00:21<2:19:52,  4.33it/s]

89	loss_total: 0.21949249505996704


  0%|          | 90/36400 [00:21<2:19:50,  4.33it/s]

90	loss_total: 0.2194749265909195


  0%|          | 91/36400 [00:21<2:18:55,  4.36it/s]

91	loss_total: 0.2128293365240097


  0%|          | 92/36400 [00:21<2:17:54,  4.39it/s]

92	loss_total: 0.21206389367580414


  0%|          | 93/36400 [00:21<2:17:32,  4.40it/s]

93	loss_total: 0.21322475373744965


  0%|          | 94/36400 [00:22<2:17:30,  4.40it/s]

94	loss_total: 0.2131960541009903


  0%|          | 95/36400 [00:22<2:17:05,  4.41it/s]

95	loss_total: 0.21250879764556885


  0%|          | 96/36400 [00:22<2:16:37,  4.43it/s]

96	loss_total: 0.21189247071743011


  0%|          | 97/36400 [00:22<2:16:28,  4.43it/s]

97	loss_total: 0.20949792861938477


  0%|          | 98/36400 [00:23<2:19:59,  4.32it/s]

98	loss_total: 0.2113179713487625


  0%|          | 99/36400 [00:23<2:18:17,  4.37it/s]

99	loss_total: 0.21425415575504303


  0%|          | 100/36400 [00:23<2:17:36,  4.40it/s]

100	loss_total: 0.21408183872699738


  0%|          | 101/36400 [00:23<2:16:25,  4.43it/s]

101	loss_total: 0.21017013490200043


  0%|          | 102/36400 [00:23<2:15:47,  4.46it/s]

102	loss_total: 0.20807045698165894


  0%|          | 103/36400 [00:24<2:17:09,  4.41it/s]

103	loss_total: 0.2166634500026703


  0%|          | 104/36400 [00:24<2:17:07,  4.41it/s]

104	loss_total: 0.21148766577243805


  0%|          | 105/36400 [00:24<2:18:26,  4.37it/s]

105	loss_total: 0.21046216785907745


  0%|          | 106/36400 [00:24<2:18:06,  4.38it/s]

106	loss_total: 0.2107342630624771


  0%|          | 107/36400 [00:25<2:20:02,  4.32it/s]

107	loss_total: 0.2181089222431183


  0%|          | 108/36400 [00:25<2:19:34,  4.33it/s]

108	loss_total: 0.21125537157058716


  0%|          | 109/36400 [00:25<2:18:39,  4.36it/s]

109	loss_total: 0.21604034304618835


  0%|          | 110/36400 [00:25<2:17:53,  4.39it/s]

110	loss_total: 0.21138013899326324


  0%|          | 111/36400 [00:26<2:17:09,  4.41it/s]

111	loss_total: 0.2111838012933731


  0%|          | 112/36400 [00:26<2:16:24,  4.43it/s]

112	loss_total: 0.204880952835083


  0%|          | 113/36400 [00:26<2:16:33,  4.43it/s]

113	loss_total: 0.21526741981506348


  0%|          | 114/36400 [00:26<2:16:02,  4.45it/s]

114	loss_total: 0.20993176102638245


  0%|          | 115/36400 [00:26<2:15:41,  4.46it/s]

115	loss_total: 0.21173585951328278


  0%|          | 116/36400 [00:27<2:17:04,  4.41it/s]

116	loss_total: 0.2141346037387848


  0%|          | 117/36400 [00:27<2:16:02,  4.45it/s]

117	loss_total: 0.21453329920768738


  0%|          | 118/36400 [00:27<2:17:18,  4.40it/s]

118	loss_total: 0.21513888239860535


  0%|          | 119/36400 [00:27<2:18:34,  4.36it/s]

119	loss_total: 0.21022915840148926


  0%|          | 120/36400 [00:28<2:17:38,  4.39it/s]

120	loss_total: 0.2093258798122406


  0%|          | 121/36400 [00:28<2:17:52,  4.39it/s]

121	loss_total: 0.2179921567440033


  0%|          | 122/36400 [00:28<2:17:15,  4.41it/s]

122	loss_total: 0.21253111958503723


  0%|          | 123/36400 [00:28<2:17:04,  4.41it/s]

123	loss_total: 0.21710935235023499


  0%|          | 124/36400 [00:28<2:16:34,  4.43it/s]

124	loss_total: 0.20948578417301178


  0%|          | 125/36400 [00:29<2:16:24,  4.43it/s]

125	loss_total: 0.21074117720127106


  0%|          | 126/36400 [00:29<2:16:28,  4.43it/s]

126	loss_total: 0.20951847732067108


  0%|          | 127/36400 [00:29<2:17:47,  4.39it/s]

127	loss_total: 0.2107563018798828


  0%|          | 128/36400 [00:29<2:27:42,  4.09it/s]

128	loss_total: 0.21752573549747467


  0%|          | 129/36400 [00:30<2:23:49,  4.20it/s]

129	loss_total: 0.21133626997470856


  0%|          | 130/36400 [00:30<2:21:50,  4.26it/s]

130	loss_total: 0.21769706904888153


  0%|          | 131/36400 [00:30<2:19:45,  4.33it/s]

131	loss_total: 0.2111247181892395


  0%|          | 132/36400 [00:30<2:19:37,  4.33it/s]

132	loss_total: 0.21057631075382233


  0%|          | 133/36400 [00:31<2:19:26,  4.33it/s]

133	loss_total: 0.21241603791713715


  0%|          | 134/36400 [00:31<2:18:40,  4.36it/s]

134	loss_total: 0.20913641154766083


  0%|          | 135/36400 [00:31<2:18:03,  4.38it/s]

135	loss_total: 0.21673887968063354


  0%|          | 136/36400 [00:31<2:16:53,  4.42it/s]

136	loss_total: 0.21030355989933014


  0%|          | 137/36400 [00:31<2:16:26,  4.43it/s]

137	loss_total: 0.20773443579673767


  0%|          | 138/36400 [00:32<2:27:57,  4.08it/s]

138	loss_total: 0.21005941927433014


  0%|          | 139/36400 [00:32<2:24:10,  4.19it/s]

139	loss_total: 0.21668171882629395


  0%|          | 140/36400 [00:32<2:21:53,  4.26it/s]

140	loss_total: 0.21712729334831238


  0%|          | 141/36400 [00:32<2:19:40,  4.33it/s]

141	loss_total: 0.2111690640449524


  0%|          | 142/36400 [00:33<2:18:44,  4.36it/s]

142	loss_total: 0.20934057235717773


  0%|          | 143/36400 [00:33<2:18:37,  4.36it/s]

143	loss_total: 0.2117396891117096


  0%|          | 144/36400 [00:33<2:18:11,  4.37it/s]

144	loss_total: 0.21643134951591492


  0%|          | 145/36400 [00:33<2:18:03,  4.38it/s]

145	loss_total: 0.2156204730272293


  0%|          | 146/36400 [00:34<2:17:14,  4.40it/s]

146	loss_total: 0.20721368491649628


  0%|          | 147/36400 [00:34<2:17:02,  4.41it/s]

147	loss_total: 0.21091872453689575


  0%|          | 148/36400 [00:34<2:19:34,  4.33it/s]

148	loss_total: 0.21171045303344727


  0%|          | 149/36400 [00:34<2:18:55,  4.35it/s]

149	loss_total: 0.212986022233963


  0%|          | 150/36400 [00:34<2:17:50,  4.38it/s]

150	loss_total: 0.2138865888118744


  0%|          | 151/36400 [00:35<2:17:16,  4.40it/s]

151	loss_total: 0.20934665203094482


  0%|          | 152/36400 [00:35<2:16:14,  4.43it/s]

152	loss_total: 0.21025590598583221


  0%|          | 153/36400 [00:35<2:16:08,  4.44it/s]

153	loss_total: 0.20939324796199799


  0%|          | 154/36400 [00:35<2:15:53,  4.45it/s]

154	loss_total: 0.21626828610897064


  0%|          | 155/36400 [00:36<2:16:15,  4.43it/s]

155	loss_total: 0.21848013997077942


  0%|          | 156/36400 [00:36<2:16:09,  4.44it/s]

156	loss_total: 0.2128901183605194


  0%|          | 157/36400 [00:36<2:16:15,  4.43it/s]

157	loss_total: 0.211566761136055


  0%|          | 158/36400 [00:36<2:28:05,  4.08it/s]

158	loss_total: 0.21074576675891876


  0%|          | 159/36400 [00:37<2:24:48,  4.17it/s]

159	loss_total: 0.21248894929885864


  0%|          | 160/36400 [00:37<2:22:32,  4.24it/s]

160	loss_total: 0.20765043795108795


  0%|          | 161/36400 [00:37<2:20:21,  4.30it/s]

161	loss_total: 0.21639667451381683


  0%|          | 162/36400 [00:37<2:18:28,  4.36it/s]

162	loss_total: 0.20980967581272125


  0%|          | 163/36400 [00:37<2:17:31,  4.39it/s]

163	loss_total: 0.2121051400899887


  0%|          | 164/36400 [00:38<2:19:12,  4.34it/s]

164	loss_total: 0.2142651528120041


  0%|          | 165/36400 [00:38<2:19:50,  4.32it/s]

165	loss_total: 0.2104870080947876


  0%|          | 166/36400 [00:38<2:19:33,  4.33it/s]

166	loss_total: 0.21236172318458557


  0%|          | 167/36400 [00:38<2:18:26,  4.36it/s]

167	loss_total: 0.2123723030090332


  0%|          | 168/36400 [00:39<2:24:52,  4.17it/s]

168	loss_total: 0.2110282927751541


  0%|          | 169/36400 [00:39<2:22:07,  4.25it/s]

169	loss_total: 0.2142583727836609


  0%|          | 170/36400 [00:39<2:20:20,  4.30it/s]

170	loss_total: 0.21522165834903717


  0%|          | 171/36400 [00:39<2:20:48,  4.29it/s]

171	loss_total: 0.21219071745872498


  0%|          | 172/36400 [00:40<2:21:07,  4.28it/s]

172	loss_total: 0.21323174238204956


  0%|          | 173/36400 [00:40<2:21:10,  4.28it/s]

173	loss_total: 0.20846106112003326


  0%|          | 174/36400 [00:40<2:20:39,  4.29it/s]

174	loss_total: 0.20922964811325073


  0%|          | 175/36400 [00:40<2:19:34,  4.33it/s]

175	loss_total: 0.21637126803398132


  0%|          | 176/36400 [00:41<2:18:27,  4.36it/s]

176	loss_total: 0.2102421671152115


  0%|          | 177/36400 [00:41<2:21:17,  4.27it/s]

177	loss_total: 0.2132132202386856


  0%|          | 178/36400 [00:41<2:33:56,  3.92it/s]

178	loss_total: 0.21260635554790497


  0%|          | 179/36400 [00:41<2:29:12,  4.05it/s]

179	loss_total: 0.2158679962158203


  0%|          | 180/36400 [00:42<2:25:11,  4.16it/s]

180	loss_total: 0.21375177800655365


  0%|          | 181/36400 [00:42<2:23:11,  4.22it/s]

181	loss_total: 0.20786526799201965


  0%|          | 182/36400 [00:42<2:20:36,  4.29it/s]

182	loss_total: 0.2139493077993393


  1%|          | 183/36400 [00:42<2:19:59,  4.31it/s]

183	loss_total: 0.2055596113204956


  1%|          | 184/36400 [00:42<2:18:11,  4.37it/s]

184	loss_total: 0.22137649357318878


  1%|          | 185/36400 [00:43<2:18:22,  4.36it/s]

185	loss_total: 0.2120995819568634


  1%|          | 186/36400 [00:43<2:17:03,  4.40it/s]

186	loss_total: 0.2121410369873047


  1%|          | 187/36400 [00:43<2:16:19,  4.43it/s]

187	loss_total: 0.2117815762758255


  1%|          | 188/36400 [00:43<2:21:04,  4.28it/s]

188	loss_total: 0.2125397026538849


  1%|          | 189/36400 [00:44<2:19:56,  4.31it/s]

189	loss_total: 0.21013294160366058


  1%|          | 190/36400 [00:44<2:18:52,  4.35it/s]

190	loss_total: 0.2133937031030655


  1%|          | 191/36400 [00:44<2:17:35,  4.39it/s]

191	loss_total: 0.21232572197914124


  1%|          | 192/36400 [00:44<2:17:35,  4.39it/s]

192	loss_total: 0.20373299717903137


  1%|          | 193/36400 [00:44<2:18:07,  4.37it/s]

193	loss_total: 0.21427197754383087


  1%|          | 194/36400 [00:45<2:18:19,  4.36it/s]

194	loss_total: 0.21336203813552856


  1%|          | 195/36400 [00:45<2:17:00,  4.40it/s]

195	loss_total: 0.21270649135112762


  1%|          | 196/36400 [00:45<2:17:54,  4.38it/s]

196	loss_total: 0.2206803858280182


  1%|          | 197/36400 [00:46<2:47:33,  3.60it/s]

197	loss_total: 0.21192529797554016


  1%|          | 198/36400 [00:46<2:39:02,  3.79it/s]

198	loss_total: 0.21869973838329315


  1%|          | 199/36400 [00:46<2:32:25,  3.96it/s]

199	loss_total: 0.21241171658039093


  1%|          | 200/36400 [00:46<2:30:40,  4.00it/s]

200	loss_total: 0.21158312261104584


  1%|          | 201/36400 [00:46<2:25:49,  4.14it/s]

201	loss_total: 0.20910722017288208


  1%|          | 202/36400 [00:47<2:25:09,  4.16it/s]

202	loss_total: 0.21499809622764587


  1%|          | 203/36400 [00:47<2:23:46,  4.20it/s]

203	loss_total: 0.2159929722547531


  1%|          | 204/36400 [00:47<2:21:21,  4.27it/s]

204	loss_total: 0.20917409658432007


  1%|          | 205/36400 [00:47<2:19:40,  4.32it/s]

205	loss_total: 0.20831052958965302


  1%|          | 206/36400 [00:48<2:18:29,  4.36it/s]

206	loss_total: 0.2119322568178177


  1%|          | 207/36400 [00:48<2:24:42,  4.17it/s]

207	loss_total: 0.21133974194526672


  1%|          | 208/36400 [00:48<2:23:59,  4.19it/s]

208	loss_total: 0.21216151118278503


  1%|          | 209/36400 [00:48<2:22:35,  4.23it/s]

209	loss_total: 0.2133399099111557


  1%|          | 210/36400 [00:49<2:20:31,  4.29it/s]

210	loss_total: 0.209523543715477


  1%|          | 211/36400 [00:49<2:20:20,  4.30it/s]

211	loss_total: 0.21360570192337036


  1%|          | 212/36400 [00:49<2:19:51,  4.31it/s]

212	loss_total: 0.21277493238449097


  1%|          | 213/36400 [00:49<2:18:36,  4.35it/s]

213	loss_total: 0.21085627377033234


  1%|          | 214/36400 [00:49<2:18:05,  4.37it/s]

214	loss_total: 0.21718475222587585


  1%|          | 215/36400 [00:50<2:17:43,  4.38it/s]

215	loss_total: 0.21600106358528137


  1%|          | 216/36400 [00:50<2:30:48,  4.00it/s]

216	loss_total: 0.21222539246082306


  1%|          | 217/36400 [00:50<2:27:03,  4.10it/s]

217	loss_total: 0.21000884473323822


  1%|          | 218/36400 [00:50<2:24:00,  4.19it/s]

218	loss_total: 0.2087743878364563


  1%|          | 219/36400 [00:51<2:22:13,  4.24it/s]

219	loss_total: 0.21657882630825043


  1%|          | 220/36400 [00:51<2:20:48,  4.28it/s]

220	loss_total: 0.21388819813728333


  1%|          | 221/36400 [00:51<2:22:09,  4.24it/s]

221	loss_total: 0.209197998046875


  1%|          | 222/36400 [00:51<2:21:28,  4.26it/s]

222	loss_total: 0.21412256360054016


  1%|          | 223/36400 [00:52<2:20:16,  4.30it/s]

223	loss_total: 0.2123550921678543


  1%|          | 224/36400 [00:52<2:37:42,  3.82it/s]

224	loss_total: 0.2130223661661148


  1%|          | 225/36400 [00:52<2:32:01,  3.97it/s]

225	loss_total: 0.2133362889289856


  1%|          | 226/36400 [00:52<2:35:54,  3.87it/s]

226	loss_total: 0.20656156539916992


  1%|          | 227/36400 [00:53<2:30:16,  4.01it/s]

227	loss_total: 0.20855431258678436


  1%|          | 228/36400 [00:53<2:27:30,  4.09it/s]

228	loss_total: 0.2078450620174408


  1%|          | 229/36400 [00:53<2:25:23,  4.15it/s]

229	loss_total: 0.2102859914302826


  1%|          | 230/36400 [00:53<2:22:06,  4.24it/s]

230	loss_total: 0.21015599370002747


  1%|          | 231/36400 [00:54<2:21:58,  4.25it/s]

231	loss_total: 0.21345005929470062


  1%|          | 232/36400 [00:54<2:22:35,  4.23it/s]

232	loss_total: 0.2145160287618637


  1%|          | 233/36400 [00:54<2:20:30,  4.29it/s]

233	loss_total: 0.21345274150371552


  1%|          | 234/36400 [00:54<2:19:12,  4.33it/s]

234	loss_total: 0.2127566933631897


  1%|          | 235/36400 [00:55<2:18:12,  4.36it/s]

235	loss_total: 0.21405400335788727


  1%|          | 236/36400 [00:55<2:17:03,  4.40it/s]

236	loss_total: 0.21223795413970947


  1%|          | 237/36400 [00:55<2:16:26,  4.42it/s]

237	loss_total: 0.2103530466556549


  1%|          | 238/36400 [00:55<2:16:00,  4.43it/s]

238	loss_total: 0.2108495980501175


  1%|          | 239/36400 [00:55<2:15:31,  4.45it/s]

239	loss_total: 0.21073228120803833


  1%|          | 240/36400 [00:56<2:16:16,  4.42it/s]

240	loss_total: 0.21696850657463074


  1%|          | 241/36400 [00:56<2:16:03,  4.43it/s]

241	loss_total: 0.20297610759735107


  1%|          | 242/36400 [00:56<2:16:06,  4.43it/s]

242	loss_total: 0.21539266407489777


  1%|          | 243/36400 [00:56<2:20:38,  4.28it/s]

243	loss_total: 0.20972739160060883


  1%|          | 244/36400 [00:57<2:18:45,  4.34it/s]

244	loss_total: 0.2137417495250702


  1%|          | 245/36400 [00:57<2:21:21,  4.26it/s]

245	loss_total: 0.2083088606595993


  1%|          | 246/36400 [00:57<2:19:21,  4.32it/s]

246	loss_total: 0.2149399071931839


  1%|          | 247/36400 [00:57<2:18:02,  4.37it/s]

247	loss_total: 0.21298442780971527


  1%|          | 248/36400 [00:58<2:17:50,  4.37it/s]

248	loss_total: 0.20908787846565247


  1%|          | 249/36400 [00:58<2:17:09,  4.39it/s]

249	loss_total: 0.21688365936279297


  1%|          | 250/36400 [00:58<2:18:07,  4.36it/s]

250	loss_total: 0.22183553874492645


  1%|          | 251/36400 [00:58<2:18:04,  4.36it/s]

251	loss_total: 0.20527009665966034


  1%|          | 252/36400 [00:58<2:17:08,  4.39it/s]

252	loss_total: 0.2068130075931549


  1%|          | 253/36400 [00:59<2:16:28,  4.41it/s]

253	loss_total: 0.20861774682998657


  1%|          | 254/36400 [00:59<2:16:21,  4.42it/s]

254	loss_total: 0.22184334695339203


  1%|          | 255/36400 [00:59<2:15:43,  4.44it/s]

255	loss_total: 0.21351538598537445


  1%|          | 256/36400 [00:59<2:15:40,  4.44it/s]

256	loss_total: 0.20768305659294128


  1%|          | 257/36400 [01:00<2:15:16,  4.45it/s]

257	loss_total: 0.21473802626132965


  1%|          | 258/36400 [01:00<2:15:20,  4.45it/s]

258	loss_total: 0.20975111424922943


  1%|          | 259/36400 [01:00<2:15:12,  4.45it/s]

259	loss_total: 0.21350890398025513


  1%|          | 260/36400 [01:00<2:15:40,  4.44it/s]

260	loss_total: 0.21283192932605743


  1%|          | 261/36400 [01:00<2:16:31,  4.41it/s]

261	loss_total: 0.21614499390125275


  1%|          | 262/36400 [01:01<2:15:40,  4.44it/s]

262	loss_total: 0.21217389404773712


  1%|          | 263/36400 [01:01<2:15:48,  4.43it/s]

263	loss_total: 0.21649062633514404


  1%|          | 264/36400 [01:01<2:20:47,  4.28it/s]

264	loss_total: 0.2140069305896759


  1%|          | 265/36400 [01:01<2:19:52,  4.31it/s]

265	loss_total: 0.21213524043560028


  1%|          | 266/36400 [01:02<2:18:22,  4.35it/s]

266	loss_total: 0.21888571977615356


  1%|          | 267/36400 [01:02<2:18:14,  4.36it/s]

267	loss_total: 0.21060419082641602


  1%|          | 268/36400 [01:02<2:17:57,  4.37it/s]

268	loss_total: 0.2101333737373352


  1%|          | 269/36400 [01:02<2:17:01,  4.39it/s]

269	loss_total: 0.2124844640493393


  1%|          | 270/36400 [01:03<2:17:06,  4.39it/s]

270	loss_total: 0.21337378025054932


  1%|          | 271/36400 [01:03<2:17:42,  4.37it/s]

271	loss_total: 0.21426165103912354


  1%|          | 272/36400 [01:03<2:17:01,  4.39it/s]

272	loss_total: 0.21455322206020355


  1%|          | 273/36400 [01:03<2:16:22,  4.41it/s]

273	loss_total: 0.21308445930480957


  1%|          | 274/36400 [01:03<2:18:24,  4.35it/s]

274	loss_total: 0.21131791174411774


  1%|          | 275/36400 [01:04<2:17:45,  4.37it/s]

275	loss_total: 0.22526083886623383


  1%|          | 276/36400 [01:04<2:17:36,  4.38it/s]

276	loss_total: 0.2113480567932129


  1%|          | 277/36400 [01:04<2:24:43,  4.16it/s]

277	loss_total: 0.21238158643245697


  1%|          | 278/36400 [01:04<2:21:45,  4.25it/s]

278	loss_total: 0.21645714342594147


  1%|          | 279/36400 [01:05<2:20:46,  4.28it/s]

279	loss_total: 0.21352913975715637


  1%|          | 280/36400 [01:05<2:19:32,  4.31it/s]

280	loss_total: 0.2154521942138672


  1%|          | 281/36400 [01:05<2:17:56,  4.36it/s]

281	loss_total: 0.21029897034168243


  1%|          | 282/36400 [01:05<2:16:49,  4.40it/s]

282	loss_total: 0.21447555720806122


  1%|          | 283/36400 [01:05<2:16:35,  4.41it/s]

283	loss_total: 0.20813894271850586


  1%|          | 284/36400 [01:06<2:15:54,  4.43it/s]

284	loss_total: 0.2128976732492447


  1%|          | 285/36400 [01:06<2:15:34,  4.44it/s]

285	loss_total: 0.20814140141010284


  1%|          | 286/36400 [01:06<2:15:10,  4.45it/s]

286	loss_total: 0.21147394180297852


  1%|          | 287/36400 [01:06<2:15:20,  4.45it/s]

287	loss_total: 0.21516291797161102


  1%|          | 288/36400 [01:07<2:15:10,  4.45it/s]

288	loss_total: 0.2150001972913742


  1%|          | 289/36400 [01:07<2:15:25,  4.44it/s]

289	loss_total: 0.21175871789455414


  1%|          | 290/36400 [01:07<2:15:14,  4.45it/s]

290	loss_total: 0.21635039150714874


  1%|          | 291/36400 [01:07<2:15:06,  4.45it/s]

291	loss_total: 0.21120256185531616


  1%|          | 292/36400 [01:08<2:16:05,  4.42it/s]

292	loss_total: 0.21948114037513733


  1%|          | 293/36400 [01:08<2:16:31,  4.41it/s]

293	loss_total: 0.20597270131111145


  1%|          | 294/36400 [01:08<2:16:33,  4.41it/s]

294	loss_total: 0.21301184594631195


  1%|          | 295/36400 [01:08<2:16:12,  4.42it/s]

295	loss_total: 0.21027176082134247


  1%|          | 296/36400 [01:08<2:15:57,  4.43it/s]

296	loss_total: 0.21437187492847443


  1%|          | 297/36400 [01:09<2:15:36,  4.44it/s]

297	loss_total: 0.2162979245185852


  1%|          | 298/36400 [01:09<2:16:37,  4.40it/s]

298	loss_total: 0.21395421028137207


  1%|          | 299/36400 [01:09<2:16:47,  4.40it/s]

299	loss_total: 0.20978935062885284


  1%|          | 300/36400 [01:09<2:16:32,  4.41it/s]

300	loss_total: 0.21596644818782806


  1%|          | 301/36400 [01:10<2:16:11,  4.42it/s]

301	loss_total: 0.21611672639846802


  1%|          | 302/36400 [01:10<2:15:56,  4.43it/s]

302	loss_total: 0.2091681808233261


  1%|          | 303/36400 [01:10<2:16:36,  4.40it/s]

303	loss_total: 0.211530864238739


  1%|          | 304/36400 [01:10<2:15:51,  4.43it/s]

304	loss_total: 0.20476223528385162


  1%|          | 305/36400 [01:10<2:15:31,  4.44it/s]

305	loss_total: 0.2118106186389923


  1%|          | 306/36400 [01:11<2:15:28,  4.44it/s]

306	loss_total: 0.21804481744766235


  1%|          | 307/36400 [01:11<2:15:36,  4.44it/s]

307	loss_total: 0.2157180905342102


  1%|          | 308/36400 [01:11<2:15:31,  4.44it/s]

308	loss_total: 0.21709595620632172


  1%|          | 309/36400 [01:11<2:15:17,  4.45it/s]

309	loss_total: 0.215057373046875


  1%|          | 310/36400 [01:12<2:15:07,  4.45it/s]

310	loss_total: 0.21019092202186584


  1%|          | 311/36400 [01:12<2:15:10,  4.45it/s]

311	loss_total: 0.21506860852241516


  1%|          | 312/36400 [01:12<2:15:30,  4.44it/s]

312	loss_total: 0.21211887896060944


  1%|          | 313/36400 [01:12<2:17:31,  4.37it/s]

313	loss_total: 0.21191160380840302


  1%|          | 314/36400 [01:13<2:17:19,  4.38it/s]

314	loss_total: 0.21788960695266724


  1%|          | 315/36400 [01:13<2:16:21,  4.41it/s]

315	loss_total: 0.20946447551250458


  1%|          | 316/36400 [01:13<2:16:13,  4.41it/s]

316	loss_total: 0.21205498278141022


  1%|          | 317/36400 [01:13<2:15:46,  4.43it/s]

317	loss_total: 0.21764878928661346


  1%|          | 318/36400 [01:13<2:15:59,  4.42it/s]

318	loss_total: 0.2149706929922104


  1%|          | 319/36400 [01:14<2:20:39,  4.28it/s]

319	loss_total: 0.2154722958803177


  1%|          | 320/36400 [01:14<2:19:35,  4.31it/s]

320	loss_total: 0.21468055248260498


  1%|          | 321/36400 [01:14<2:19:11,  4.32it/s]

321	loss_total: 0.21474406123161316


  1%|          | 322/36400 [01:14<2:18:11,  4.35it/s]

322	loss_total: 0.21631240844726562


  1%|          | 323/36400 [01:15<2:17:54,  4.36it/s]

323	loss_total: 0.2130049467086792


  1%|          | 324/36400 [01:15<2:17:18,  4.38it/s]

324	loss_total: 0.21978497505187988


  1%|          | 325/36400 [01:15<2:16:09,  4.42it/s]

325	loss_total: 0.21252891421318054


  1%|          | 326/36400 [01:15<2:16:04,  4.42it/s]

326	loss_total: 0.21355189383029938


  1%|          | 327/36400 [01:15<2:17:05,  4.39it/s]

327	loss_total: 0.212412029504776


  1%|          | 328/36400 [01:16<2:16:43,  4.40it/s]

328	loss_total: 0.21304713189601898


  1%|          | 329/36400 [01:16<2:17:43,  4.37it/s]

329	loss_total: 0.214911550283432


  1%|          | 330/36400 [01:16<2:26:21,  4.11it/s]

330	loss_total: 0.2095024436712265


  1%|          | 331/36400 [01:16<2:23:14,  4.20it/s]

331	loss_total: 0.21362465620040894


  1%|          | 332/36400 [01:17<2:20:56,  4.27it/s]

332	loss_total: 0.20891861617565155


  1%|          | 333/36400 [01:17<2:18:52,  4.33it/s]

333	loss_total: 0.2155054807662964


  1%|          | 334/36400 [01:17<2:17:57,  4.36it/s]

334	loss_total: 0.2109174132347107


  1%|          | 335/36400 [01:17<2:16:47,  4.39it/s]

335	loss_total: 0.21515215933322906


  1%|          | 336/36400 [01:18<2:16:29,  4.40it/s]

336	loss_total: 0.21614088118076324


  1%|          | 337/36400 [01:18<2:16:39,  4.40it/s]

337	loss_total: 0.2172158658504486


  1%|          | 338/36400 [01:18<2:16:02,  4.42it/s]

338	loss_total: 0.21807537972927094


  1%|          | 339/36400 [01:18<2:15:58,  4.42it/s]

339	loss_total: 0.21861858665943146


  1%|          | 340/36400 [01:18<2:15:34,  4.43it/s]

340	loss_total: 0.2164120227098465


  1%|          | 341/36400 [01:19<2:15:30,  4.44it/s]

341	loss_total: 0.21096408367156982


  1%|          | 342/36400 [01:19<2:15:38,  4.43it/s]

342	loss_total: 0.215642049908638


  1%|          | 343/36400 [01:19<2:26:09,  4.11it/s]

343	loss_total: 0.21486778557300568


  1%|          | 344/36400 [01:19<2:23:24,  4.19it/s]

344	loss_total: 0.21413184702396393


  1%|          | 345/36400 [01:20<2:21:50,  4.24it/s]

345	loss_total: 0.2137601524591446


  1%|          | 346/36400 [01:20<2:23:00,  4.20it/s]

346	loss_total: 0.21171167492866516


  1%|          | 347/36400 [01:20<2:20:39,  4.27it/s]

347	loss_total: 0.2110518217086792


  1%|          | 348/36400 [01:20<2:19:01,  4.32it/s]

348	loss_total: 0.213357612490654


  1%|          | 349/36400 [01:21<2:19:18,  4.31it/s]

349	loss_total: 0.21301202476024628


  1%|          | 350/36400 [01:21<2:19:21,  4.31it/s]

350	loss_total: 0.2137451469898224


  1%|          | 351/36400 [01:21<2:19:10,  4.32it/s]

351	loss_total: 0.21184448897838593


  1%|          | 352/36400 [01:21<2:18:47,  4.33it/s]

352	loss_total: 0.2095434069633484


  1%|          | 353/36400 [01:21<2:18:14,  4.35it/s]

353	loss_total: 0.21660283207893372


  1%|          | 354/36400 [01:22<2:26:41,  4.10it/s]

354	loss_total: 0.21447347104549408


  1%|          | 355/36400 [01:22<2:24:07,  4.17it/s]

355	loss_total: 0.20930011570453644


  1%|          | 356/36400 [01:22<2:22:09,  4.23it/s]

356	loss_total: 0.21505451202392578


  1%|          | 357/36400 [01:22<2:20:19,  4.28it/s]

357	loss_total: 0.21212221682071686


  1%|          | 358/36400 [01:23<2:18:46,  4.33it/s]

358	loss_total: 0.21039721369743347


  1%|          | 359/36400 [01:23<2:17:41,  4.36it/s]

359	loss_total: 0.21197326481342316


  1%|          | 360/36400 [01:23<2:20:20,  4.28it/s]

360	loss_total: 0.21663916110992432


  1%|          | 361/36400 [01:23<2:19:16,  4.31it/s]

361	loss_total: 0.21371424198150635


  1%|          | 362/36400 [01:24<2:19:06,  4.32it/s]

362	loss_total: 0.2095167487859726


  1%|          | 363/36400 [01:24<2:19:32,  4.30it/s]

363	loss_total: 0.2111210972070694


  1%|          | 363/36400 [01:24<2:19:56,  4.29it/s]

364	loss_total: 0.21032653748989105





ValueError: Expected 3D tensor with dimensions (batch, channel, frames). Found: torch.Size([40960])