In [1]:
!pip install openai-whisper
!pip install whisperx
!pip install TorchCRF

Collecting openai-whisper
  Downloading openai_whisper-20250625.tar.gz (803 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/803.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/803.2 kB[0m [31m4.6 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m798.7/803.2 kB[0m [31m13.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m803.2/803.2 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->openai-whisper)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->openai-w

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchaudio
from whisper import load_model
import whisper
import numpy as np
import whisperx

Encoder_DIM = 512

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
whisper_model = load_model(name='base.en', download_root='./').to(device)
whisper_model.eval()

100%|███████████████████████████████████████| 139M/139M [00:13<00:00, 10.9MiB/s]


Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-5): 6 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=512, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=False)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (out): Linear(in_features=512, out_features=512, bias=True)
        )
        (attn_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (mlp_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((512,), eps=1e-05,

In [3]:
fcount = 0
def build_truncation_labels(word_segments, num_frames, frame_stride=0.02):
    global fcount
    labels = np.zeros(num_frames, dtype=np.float32)
    end_time = 0
    start_time = 0
    for word in word_segments:
        start_time = word['start']
        end_time = word['end']
        token_end_frame = int(round(end_time / frame_stride))
        token_start_frame = int(round(start_time / frame_stride))
        alpha = 1 / (token_end_frame - token_start_frame + 1)
        labels[token_start_frame:token_end_frame+1] += alpha

    return torch.tensor(labels, dtype=torch.float32)

In [None]:
print(whisper.__file__)

/usr/local/lib/python3.11/dist-packages/whisper/__init__.py


In [4]:
import random
dataset = torchaudio.datasets.LIBRISPEECH('./data', url="train-clean-100", download=True)
align_model, metadata = whisperx.load_align_model(language_code='en', device=device)
labels = []
encoder_outs = []
count = 0

for i in range(len(dataset)):
    if(count >= 150000):
      break
    waveform, sr, transcript, *_ = dataset[i]
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    res = {
            "segments": [{
                "start": 0,
                "end": waveform.shape[1] / sr,
                "text": transcript
            }]
        }
    result = whisperx.align(res["segments"], align_model, metadata, waveform[0], device, return_char_alignments=False)
    word_segments = result["word_segments"]
    label = build_truncation_labels(word_segments, 3000)
    length = waveform.shape[1] / 16000
    flag = 0
    k = -1
    for word in word_segments:
        k += 1
        if(k%5 != 0):
          continue
        chunk_len = 1 if random.random() < 0.2 else 2
        start = word['end']
        end = start + chunk_len
        chunk = waveform[:, int(start * 16000):int(end * 16000)]
        if end > length:
            flag = 1
        chunk_padded = whisper.pad_or_trim(chunk)
        mel = whisper.log_mel_spectrogram(chunk_padded).to(device)
        with torch.no_grad():
            encoder_out = whisper_model.encoder(mel)
        encoder_out = encoder_out.squeeze(0)[:100]
        token_start = int(round(start / 0.02))
        token_end = token_start + chunk_len * 50
        label_chunk = label[token_start:token_end]
        pad_label = torch.zeros(encoder_out.shape[0] - label_chunk.shape[0])
        label_chunk = torch.cat((label_chunk, pad_label))
        count += 1
        labels.append(label_chunk.cpu())
        encoder_outs.append(encoder_out.cpu())
        if flag == 1:
            break

100%|██████████| 5.95G/5.95G [03:36<00:00, 29.4MB/s]
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth
100%|██████████| 360M/360M [00:01<00:00, 317MB/s]


In [5]:
dev_dataset = torchaudio.datasets.LIBRISPEECH('./data', url="dev-clean", download=True)
align_model, metadata = whisperx.load_align_model(language_code='en', device=device)
dev_labels = []
dev_encoder_outs = []
count = 0

for i in range(len(dev_dataset)):
    if(count >= 8000):
      break
    waveform, sr, transcript, *_ = dev_dataset[i]
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    res = {
            "segments": [{
                "start": 0,
                "end": waveform.shape[1] / sr,
                "text": transcript
            }]
        }
    result = whisperx.align(res["segments"], align_model, metadata, waveform[0], device, return_char_alignments=False)
    word_segments = result["word_segments"]
    label = build_truncation_labels(word_segments, 3000)
    length = waveform.shape[1] / 16000
    flag = 0
    k = -1
    for word in word_segments:
        k += 1
        if(k%5 != 0):
          continue
        chunk_len = 1 if random.random() < 0.2 else 2
        start = word['end']
        end = start + chunk_len
        chunk = waveform[:, int(start * 16000):int(end * 16000)]
        if end > length:
            flag = 1
        chunk_padded = whisper.pad_or_trim(chunk)
        mel = whisper.log_mel_spectrogram(chunk_padded).to(device)
        with torch.no_grad():
            encoder_out = whisper_model.encoder(mel)
        encoder_out = encoder_out.squeeze(0)[:100]
        token_start = int(round(start / 0.02))
        token_end = token_start + chunk_len * 50
        label_chunk = label[token_start:token_end]
        pad_label = torch.zeros(encoder_out.shape[0] - label_chunk.shape[0])
        label_chunk = torch.cat((label_chunk, pad_label))
        count += 1
        dev_labels.append(label_chunk.cpu())
        dev_encoder_outs.append(encoder_out.cpu())
        if flag == 1:
            break

100%|██████████| 322M/322M [00:17<00:00, 19.8MB/s]


In [6]:
test_dataset = torchaudio.datasets.LIBRISPEECH('./data', url="test-clean", download=True)
align_model, metadata = whisperx.load_align_model(language_code='en', device=device)
test_labels = []
test_encoder_outs = []
count = 0

for i in range(len(test_dataset)):
    if(count >= 8000):
      break
    waveform, sr, transcript, *_ = test_dataset[i]
    if sr != 16000:
        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
    res = {
            "segments": [{
                "start": 0,
                "end": waveform.shape[1] / sr,
                "text": transcript
            }]
        }
    result = whisperx.align(res["segments"], align_model, metadata, waveform[0], device, return_char_alignments=False)
    word_segments = result["word_segments"]
    label = build_truncation_labels(word_segments, 3000)
    length = waveform.shape[1] / 16000
    flag = 0
    k = 0
    for word in word_segments:
        k += 1
        if(k%5 != 0):
          continue
        chunk_len = 1 if random.random() < 0.2 else 2
        start = word['end']
        end = start + chunk_len
        chunk = waveform[:, int(start * 16000):int(end * 16000)]
        if end > length:
            flag = 1
        chunk_padded = whisper.pad_or_trim(chunk)
        mel = whisper.log_mel_spectrogram(chunk_padded).to(device)
        with torch.no_grad():
            encoder_out = whisper_model.encoder(mel)
        encoder_out = encoder_out.squeeze(0)[:100]
        token_start = int(round(start / 0.02))
        token_end = token_start + chunk_len * 50
        label_chunk = label[token_start:token_end]
        pad_label = torch.zeros(encoder_out.shape[0] - label_chunk.shape[0])
        label_chunk = torch.cat((label_chunk, pad_label))
        count += 1
        test_labels.append(label_chunk.cpu())
        test_encoder_outs.append(encoder_out.cpu())
        if flag == 1:
            break

100%|██████████| 331M/331M [00:10<00:00, 33.5MB/s]


In [7]:
class LSTMDataset(Dataset):
    def __init__(self, encoder_outs, labels, frame_stride=0.02):
        self.encoder_outs = encoder_outs
        self.frame_stride = frame_stride
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Get encoder output
        encoder_out = self.encoder_outs[idx]
        encoder_out = encoder_out.to(device)
        # Get label
        label = self.labels[idx]
        label = label.to(device)
        return encoder_out, label

In [8]:
def collate_fn(batch):
    encs, labs = zip(*batch)
    encs = torch.stack(encs)  # [B, T, D]
    labs = torch.stack(labs)  # [B, T]
    return encs, labs

In [9]:
train_dataset = LSTMDataset(encoder_outs, labels)
dev_dataset = LSTMDataset(dev_encoder_outs, dev_labels)
test_dataset = LSTMDataset(test_encoder_outs, test_labels)
train_size = len(train_dataset)
dev_size = len(dev_dataset)
test_size = len(test_dataset)
print(f"train_size: {train_size}, dev_size: {dev_size}, test_size: {test_size}")

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
dev_dataloader = DataLoader(dev_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

train_size: 150001, dev_size: 8001, test_size: 8004


In [10]:
class LSTMTruncationDetector(nn.Module):
    def __init__(self, input_dim, hidden_dim=512, num_layers=2, bidirectional=True, dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, bidirectional=bidirectional, dropout=dropout)
        self.out_proj = nn.Linear(hidden_dim * (2 if bidirectional else 1), 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # x: [B, T, D]
        lstm_out, _ = self.lstm(x)          # [B, T, H]
        lstm_out = self.dropout(lstm_out)
        logits = self.out_proj(lstm_out)    # [B, T, 1]
        alphas = torch.sigmoid(logits).squeeze(-1)  # [B, T]
        return alphas

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CIFTimeLoss(nn.Module):
    def __init__(self, lambda_time: float = 1.0, lambda_count: float = 1.0, lambda_blank: float = 1.0, eps: float = 1e-6):
        super().__init__()
        self.lambda_time = lambda_time
        self.lambda_count = lambda_count
        self.lambda_blank = lambda_blank
        self.eps = eps
        self.frame_reg_loss = nn.SmoothL1Loss(reduction='mean', beta=3.0)
        self.count_loss_fn = nn.SmoothL1Loss(reduction='mean', beta=0.1)

    def forward(self, alpha: torch.Tensor, true_counts: torch.Tensor, true_frames_list: list, labs: torch.Tensor, epoch: int):
        device = alpha.device
        B, T = alpha.shape

        sum_alpha = alpha.sum(dim=1)
        l_count = self.count_loss_fn(sum_alpha, true_counts)

        A = torch.cumsum(alpha, dim=1)

        l_time_sum = torch.tensor(0.0, device=device)
        valid_samples = 0
        beta = 2.0 + epoch * 0.5
        #ta = 1 / epoch

        blank_mask = (labs == 0).float()
        l_blank = (alpha * blank_mask).sum() # sum or mean

        for i in range(B):
            U_i = int(true_counts[i].item())
            if U_i <= 0:
                continue

            A_i = A[i]
            thresholds = torch.arange(1, U_i+1, device=device, dtype=A_i.dtype)

            diff = (A_i.unsqueeze(0) - thresholds.unsqueeze(1)).abs()
            penalty = (A_i.unsqueeze(0) - thresholds.unsqueeze(1)).clamp(min=0)
            # f = diff + 15 * penalty
            #weights = F.gumbel_softmax(diff, tau=ta, hard=True)
            weights = F.softmax(-beta * diff - 20 * penalty, dim=1)

            t_idx = torch.arange(T, device=device, dtype=torch.float32)
            pred_frames = (weights * t_idx).sum(dim=1)

            true_frames = true_frames_list[i].to(device).float()
            l_time_i = self.frame_reg_loss(pred_frames, true_frames)

            l_time_sum += l_time_i
            valid_samples += 1

        l_time = (l_time_sum / valid_samples) if valid_samples>0 else torch.tensor(0.0, device=device)
        loss   = self.lambda_count * l_count + self.lambda_time * l_time + self.lambda_blank * l_blank
        return loss, l_count, l_time, l_blank

In [12]:
model = LSTMTruncationDetector(input_dim=Encoder_DIM).to(device)
criterion = CIFTimeLoss(lambda_count=2.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [13]:
def get_frame_lists(lab):
    cumsum = lab.cumsum(dim=0)
    total_tokens = int(lab.sum().item())
    fire_indices = []
    a = 1e-4
    for k in range(1, total_tokens + 1):
        idx = torch.searchsorted(cumsum, torch.tensor(k-a, device=cumsum.device)).item()
        fire_indices.append(idx)
    return torch.tensor(fire_indices, dtype=torch.long)

In [14]:
least_dev_loss = 1e9

In [15]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
num_epochs = 20
for epoch in range(1, num_epochs + 1):
    model.train()
    epoch_loss = 0.0
    dev_loss = 0.0
    dev_loss_count = 0.0
    dev_loss_frame = 0.0
    dev_loss_blank = 0.0
    all_diffs = []

    for enc, lab in train_dataloader:
        batch_size, seq_len, _ = enc.size()
        alphas = model(enc)
        true_counts = lab.sum(dim=-1).float()
        true_frames_list = [get_frame_lists(lab[i]) for i in range(batch_size)]
        loss, _, _, _ = criterion(alphas, true_counts, true_frames_list, lab, epoch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    model.eval()
    for enc, lab in test_dataloader:
        with torch.no_grad():
            batch_size, seq_len, _ = enc.size()
            alphas = model(enc)
            true_counts = lab.sum(dim=-1).float()
            true_frames_list = [get_frame_lists(lab[i]) for i in range(batch_size)]
            pred_counts = alphas.sum(dim=-1).float()
            diff = (pred_counts - true_counts).abs().cpu().numpy().tolist()
            loss, loss_count, loss_frame, loss_blank = criterion(alphas, true_counts, true_frames_list, lab, epoch)
            dev_loss += loss.item()
            dev_loss_count += loss_count.item()
            dev_loss_frame += loss_frame.item()
            dev_loss_blank += loss_blank.item()
            all_diffs.extend(diff)
    epoch_loss /= len(train_dataloader)
    dev_loss /= len(test_dataloader)
    dev_loss_count /= len(test_dataloader)
    dev_loss_frame /= len(test_dataloader)
    dev_loss_blank /= len(test_dataloader)
    avg_diff = np.mean(all_diffs)
    print(f"Epoch {epoch}/{num_epochs}, Loss: {epoch_loss:.4f}, Avg Diff: {avg_diff:.4f}, Dev Loss: {dev_loss:.4f}")
    print(f"  Count Loss: {dev_loss_count:.4f}, Frame Loss: {dev_loss_frame:.4f}, Blank Loss: {dev_loss_blank:.4f}")
    if dev_loss < least_dev_loss:
        least_dev_loss = dev_loss
        torch.save(model.state_dict(), f"model.pth")
        print("Model saved!")

Epoch 1/20, Loss: 17.7949, Avg Diff: 0.3481, Dev Loss: 5.0657
  Count Loss: 0.3035, Frame Loss: 4.0894, Blank Loss: 0.3692
Model saved!
Epoch 2/20, Loss: 4.5585, Avg Diff: 0.3061, Dev Loss: 4.4485
  Count Loss: 0.2626, Frame Loss: 3.3451, Blank Loss: 0.5782
Model saved!
Epoch 3/20, Loss: 4.0870, Avg Diff: 0.3029, Dev Loss: 3.8291
  Count Loss: 0.2593, Frame Loss: 2.9411, Blank Loss: 0.3696
Model saved!
Epoch 4/20, Loss: 3.6911, Avg Diff: 0.2926, Dev Loss: 3.5032
  Count Loss: 0.2490, Frame Loss: 2.6726, Blank Loss: 0.3326
Model saved!
Epoch 5/20, Loss: 3.5296, Avg Diff: 0.3053, Dev Loss: 3.6057
  Count Loss: 0.2613, Frame Loss: 2.8565, Blank Loss: 0.2267
Epoch 6/20, Loss: 3.6987, Avg Diff: 0.2950, Dev Loss: 3.5349
  Count Loss: 0.2514, Frame Loss: 2.5099, Blank Loss: 0.5221
Epoch 7/20, Loss: 3.3627, Avg Diff: 0.2927, Dev Loss: 3.5703
  Count Loss: 0.2499, Frame Loss: 2.5984, Blank Loss: 0.4720
Epoch 8/20, Loss: 3.3213, Avg Diff: 0.2996, Dev Loss: 3.5351
  Count Loss: 0.2564, Frame Loss

KeyboardInterrupt: 

In [16]:
def get_token_frame_intervals(alphas: torch.Tensor, threshold: float = 1.0):
    B, T = alphas.size()
    token_counts = []
    intervals = []

    for b in range(B):
        integrate = 0.0
        prev_fire = -1
        samps = alphas[b].tolist()

        this_intervals = []
        for t, a in enumerate(samps):
            integrate += a
            if integrate >= threshold:
                start = prev_fire + 1
                end   = t
                this_intervals.append((start, end))
                prev_fire = t
                integrate -= threshold

        token_counts.append(len(this_intervals))
        intervals.append(this_intervals)

    return token_counts, intervals

In [17]:
def get_token_frame_intervals_single(alphas: torch.Tensor, threshold: float = 1.0):
    integrate = 0.0
    prev_fire = -1
    samps = alphas.tolist()

    intervals = []
    for t, a in enumerate(samps):
        integrate += a
        if integrate >= threshold:
            start = prev_fire + 1
            end = t
            intervals.append((start, end))
            prev_fire = t
            integrate -= threshold

    token_count = len(intervals)
    return token_count, intervals

In [None]:
# def resize(alphas: torch.Tensor,
#            target_lengths: torch.Tensor,
#            noise: float = 0.0,
#            threshold: float = 1.0,
#            max_iter: int = 20):
#     device = alphas.device
#     B, T = alphas.size()

#     orig_sums = alphas.sum(dim=-1)

#     num = float(target_lengths)
#     if noise > 0:
#         num = num + noise * torch.rand_like(num)

#     scale = (num / (orig_sums + 1e-8)).unsqueeze(1)
#     resized = alphas * scale

#     for _ in range(max_iter):
#         mask_exceed = resized > threshold
#         if not mask_exceed.any():
#             break
#         for b in torch.unique(mask_exceed.nonzero()[:,0]):
#             row = resized[b]
#             mask = row.ne(0).float()
#             mean_val = 0.5 * row.sum() / (mask.sum() + 1e-8)
#             resized[b] = row * 0.5 + mean_val * mask

#     return resized, orig_sums

In [18]:
def resize_single(alphas: torch.Tensor,
              target_length,
              noise: float = 0.0,
              threshold: float = 1.0,
              max_iter: int = 20):

    device = alphas.device
    orig_sum = alphas.sum()

    if isinstance(target_length, torch.Tensor):
        num = target_length.to(device)
    else:
        num = torch.tensor(float(target_length), device=device)

    if noise > 0:
        num = num + noise * torch.rand((), device=device)

    scale = num / (orig_sum + 1e-8)
    resized = alphas * scale

    for _ in range(max_iter):
        mask_exceed = resized > threshold
        if not mask_exceed.any():
            break

        row = resized
        mask = row.ne(0).float()
        mean_val = 0.5 * row.sum() / (mask.sum() + 1e-8)
        resized = row * 0.5 + mean_val * mask

    return resized, orig_sum

In [19]:
def truncate_alphas(alphas: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    B, T = alphas.shape
    device = alphas.device

    truncated = alphas.clone()

    for b in range(B):
        row = alphas[b]
        A = torch.cumsum(row, dim=0)

        total = A[-1].item()
        K = torch.floor(torch.tensor(total / threshold, device=device)) * threshold
        if K < threshold:
            left = 0
            continue
        left = total - K
        idx = (A >= K).nonzero(as_tuple=False)
        t_last = idx[0, 0].item()

        if t_last + 1 < T:
            truncated[b, t_last+1 :] = 0.0

    return truncated, left

In [20]:
def Judge_truncate(pred_interval, left, alpha):
    if len(pred_interval) == 0:
        return False, alpha
    last_interval = pred_interval[-1]
    if last_interval[1] == 49 and len(pred_interval) != 1:
        alpha[last_interval[0]:] = 0.0
        return True, alpha
    if last_interval[1] <= 45 and left >= 0.09:
        return True, alpha
    return False, alpha

In [21]:
def train_frame(alpha, count):
    U_i = int(count)
    if U_i <= 0:
        return []
    A_i = torch.cumsum(alpha, dim=0)
    thresholds = torch.arange(1, U_i+1, device=device, dtype=A_i.dtype)

    diff = (A_i.unsqueeze(0) - thresholds.unsqueeze(1)).abs()
    penalty = (A_i.unsqueeze(0) - thresholds.unsqueeze(1)).clamp(min=0)
    f = diff
    weights = F.softmax(-10 * f - 10 * penalty, dim=1)

    t_idx = torch.arange(50, device=device, dtype=torch.float32)
    pred_frames = (weights * t_idx).sum(dim=1)
    return pred_frames

In [22]:
model.load_state_dict(torch.load("model.pth"))

<All keys matched successfully>

In [23]:
model.eval()
ct1 = 0
ct2 = 0
diffs = []
criterion = CIFTimeLoss()
for enc, lab in test_dataloader:
    with torch.no_grad():
        batch_size, seq_len, _ = enc.size()
        alphas = model(enc)
        true_counts = lab.sum(dim=-1).float()
        for b in range(enc.size(0)):
            true_count = true_counts[b]
            a_sum = alphas[b].sum()
            train_frames = train_frame(alphas[b], true_count)
            true_frames = get_frame_lists(lab[b])
            pred_token_count, pred_interval = get_token_frame_intervals_single(alphas[b], threshold=0.99)
            # if true_count != 1:
            #   continue
            print(f"  true_counts: {true_counts[b]}, a_sum: {a_sum}")
            print(f"  train_frames: {train_frames}")
            print(f"  true_frames: {true_frames}")
            print(f"  pred_interval: {pred_interval}")
            print(alphas[b][0:10])


RuntimeError: The size of tensor a (100) must match the size of tensor b (50) at non-singleton dimension 1

In [None]:
np.max(diffs)