In [None]:
## Further investigation found that the slow MNIST batch generator was caused by slow per-row (1D) indexing on the CPU.
## Switching to 2D indexing plus a list comprehension, which removes those for loops, speeds up the generator by about 6 times.
## The data transfer from CPU to GPU takes only a small percentage of the run time.

In [1]:
from concurrent.futures import (ALL_COMPLETED, ThreadPoolExecutor,
                                as_completed, wait)
from time import time

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

In [2]:
# Precomputed tensors used by every cell below. Going by the file names these are
# presumably DINOv2-small embeddings of the MNIST training images and their digit
# labels -- TODO confirm against the script that produced the files.
embedding_tensor_path = "../datasets/mnst_train_dinov2_small.pt"
label_tensor_path = "../datasets/mnst_train_labels.pt"
# Load both files in one pass; the order (embedding first, then labels) is preserved.
embedding, labels = (
    torch.load(tensor_path)
    for tensor_path in (embedding_tensor_path, label_tensor_path)
)

In [5]:
# Number of embedding rows available to sample from.
n = embedding.shape[0]
# Bags per batch, and the padded length (maximum number of items) of every bag.
batch_size, max_bag_length = 128, 40

In [108]:
def get_one_batch():
    random_bag_lengths = np.clip(
                    np.random.poisson(20, size=(batch_size)).astype(int),
                    1,
                    max_bag_length,
                )
    attention_mask = torch.tensor([[1]*l + [0]*(max_bag_length - l) for l in random_bag_lengths], dtype=(torch.float32))
    batched_random_indices = np.random.randint(n, size=(batch_size, max_bag_length))
    input_tensor = embedding[batched_random_indices]
    label_tensor = (torch.sum((labels[batched_random_indices] == 9) * attention_mask, axis=1) >= 4).to(torch.float32)
    return input_tensor, label_tensor, attention_mask

def get_one_batch2():
    """Loop-based baseline generator, kept for timing against ``get_one_batch``.

    Produces the same kind of batch as ``get_one_batch`` but fills the mask,
    inputs and labels one bag at a time with Python ``for`` loops and 1D
    indexing. Reads the notebook-level globals ``embedding``, ``labels``, ``n``,
    ``batch_size`` and ``max_bag_length``; the feature width is hard-coded to 384.

    Returns:
        tuple: ``(input_tensor, label_tensor, attention_mask)`` with shapes
        ``(batch_size, max_bag_length, 384)``, ``(batch_size,)`` and
        ``(batch_size, max_bag_length)``, all float32. A label is 1.0 when at
        least 4 of the first ``bag_length`` items of the bag have label 9.
    """
    # Bag sizes: Poisson(20) draws clamped into [1, max_bag_length].
    bag_lengths = np.random.poisson(20, size=batch_size).astype(int)
    bag_lengths = np.clip(bag_lengths, 1, max_bag_length)

    # Switch on the first `length` slots of every row.
    attention_mask = torch.zeros((batch_size, max_bag_length), dtype=torch.float32)
    for row, length in enumerate(bag_lengths):
        attention_mask[row, :length] = 1

    sample_ids = np.random.randint(n, size=(batch_size, max_bag_length))

    # Gather embeddings bag by bag (1D indexing per row -- the slow path).
    input_tensor = torch.zeros((batch_size, max_bag_length, 384), dtype=torch.float32)
    for row, row_ids in enumerate(sample_ids):
        input_tensor[row] = embedding[row_ids]

    # Label each bag from its real (non-padded) items only.
    label_tensor = torch.zeros(batch_size, dtype=torch.float32)
    for row, (row_ids, length) in enumerate(zip(sample_ids, bag_lengths)):
        nine_count = torch.sum(labels[row_ids][:length] == 9)
        label_tensor[row] = int(nine_count >= 4)
    return input_tensor, label_tensor, attention_mask

In [109]:
## CPU-only benchmark of get_one_batch: 500 batches, each followed by a trivial "+ 1".
## Displays (wall time of the whole loop, time spent inside the generator alone).
output_tensor_cpu = torch.zeros((batch_size, max_bag_length, 384), dtype=torch.float32)
get_batch_times = []
loop_start = time()
for _ in range(500):
    batch_start = time()
    tem, label, mask = get_one_batch()
    # Record only the generator call, not the "+ 1" that follows.
    get_batch_times.append(time() - batch_start)
    output_tensor_cpu = tem + 1
loop_end = time()
loop_end - loop_start, np.sum(get_batch_times)

(1.2856097221374512, 1.0330467224121094)

In [111]:
## CPU-only benchmark of get_one_batch2 (loop-based baseline): 500 batches, each followed by a trivial "+ 1".
## Displays (wall time of the whole loop, time spent inside the generator alone).
output_tensor_cpu = torch.zeros((batch_size, max_bag_length, 384), dtype=torch.float32)
get_batch_times = []
loop_start = time()
for _ in range(500):
    batch_start = time()
    tem, label, mask = get_one_batch2()
    # Record only the generator call, not the "+ 1" that follows.
    get_batch_times.append(time() - batch_start)
    output_tensor_cpu = tem + 1
loop_end = time()
loop_end - loop_start, np.sum(get_batch_times)

(6.981238842010498, 6.7260212898254395)

In [112]:
## GPU benchmark: batches are still generated on the CPU by get_one_batch, then
## copied to the GPU where the trivial "+ 1" runs. Displays the loop wall time.
output_tensor_gpu = torch.zeros((batch_size, max_bag_length, 384), dtype=(torch.float32)).cuda()
# CUDA work is queued asynchronously. Drain anything already pending (e.g. the
# allocation above) so it is not billed to the timed loop.
torch.cuda.synchronize()
t1 = time()
for i in range(500):
    tem, label, mask = get_one_batch()
    tem = tem.cuda(non_blocking=True)
    output_tensor_gpu = tem + 1
# Wait for every queued copy/kernel to finish before reading the clock;
# without this, t2 can be taken while GPU work is still in flight.
torch.cuda.synchronize()
t2 = time()
t2 - t1

1.4963743686676025

In [96]:
# Scratch check: compute the bag labels with the original per-row loop so the
# vectorized formula in the next cell can be compared against it.
random_bag_lengths = np.clip(
                    np.random.poisson(20, size=(batch_size)).astype(int),
                    1,
                    max_bag_length,
                )
attention_mask = torch.tensor(
    [[1] * length + [0] * (max_bag_length - length) for length in random_bag_lengths],
    dtype=(torch.float32),
)
# Draw the index grid here. This cell used to rely on a `batched_random_indices`
# left over in the kernel, which raises NameError after Restart & Run All.
batched_random_indices = np.random.randint(n, size=(batch_size, max_bag_length))

# Reference labels: 1.0 when at least 4 of a bag's real (non-padded) items are 9s.
label_tensor = torch.zeros((batch_size), dtype=(torch.float32))
for i in range(batch_size):
    label_tensor[i] = int(
        torch.sum(labels[batched_random_indices[i]][: random_bag_lengths[i]] == 9)
        >= 4
    )

In [100]:
# Vectorized labels: flag label-9 items, zero out padded slots with the mask,
# then require at least 4 hits per bag.
nine_hits = (labels[batched_random_indices] == 9) * attention_mask
label_tensor2 = (nine_hits.sum(dim=1) >= 4).to(torch.float32)

In [103]:
# Number of bags where the loop-based and vectorized labels disagree (0 means they match).
(label_tensor != label_tensor2).sum()

tensor(0)