In [2]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from bio_if.modules.mlp import MLPBlock
from tqdm import tqdm

In [3]:
# load mnist dataset as vectors
gen = torch.Generator().manual_seed(42)
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('data', train=False, download=True, transform=transform)
# grab 10% as in Bae et al. 2020
train_data, _ = torch.utils.data.random_split(train_data, [6000, 54000], generator=gen)
test_data, _ = torch.utils.data.random_split(test_data, [1000, 9000], generator=gen)
train_data, val_data = torch.utils.data.random_split(train_data, [5000, 1000], generator=gen)

In [4]:
# select 20 random samples from the training set
idx = torch.randperm(len(train_data), generator=gen)[:20]

# create 20 datasets excluding these samples removed
datasets_removed = [
    torch.utils.data.Subset(train_data, torch.cat([torch.arange(i), torch.arange(i+1, len(train_data))]))
    for i in range(20)
] 
assert all(len(d) == len(train_data) - 1 for d in datasets_removed)

In [5]:
def generate_model():
    D = 128
    return nn.Sequential(
        MLPBlock(784, D),
        MLPBlock(D, D),
        MLPBlock(D, 10, use_relu=False),
    )

In [20]:
# train a model on the original dataset
DEVICE = "cuda:7"
EPOCHS = 1000
model = generate_model().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay=1e-2)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=128, shuffle=False)

val_loss, train_loss, val_acc = 0, 0, 0
for epoch in range(EPOCHS):
    model.train()
    for x, y in tqdm(train_loader, desc=f'Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, val_acc={acc:.4f}', position=0, leave=True):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = nn.functional.cross_entropy(y_pred, y)
        train_loss = loss.item()
        loss.backward()
        optimizer.step()
    model.eval()
    correct = 0
    val_loss = 0
    for x, y in val_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        y_pred = model(x)
        val_loss += nn.functional.cross_entropy(y_pred, y.to(DEVICE), reduction='sum').item()
        correct += (y_pred.argmax(dim=1) == y).sum().item()
    val_loss /= len(val_data)
    acc = correct / len(val_data)

Epoch 0: train_loss=0.0000, val_loss=0.0000, val_acc=0.0980:  10%|█         | 4/40 [00:00<00:01, 32.04it/s]

Epoch 0: train_loss=0.0000, val_loss=0.0000, val_acc=0.0980: 100%|██████████| 40/40 [00:01<00:00, 39.61it/s]
Epoch 1: train_loss=2.2788, val_loss=2.3028, val_acc=0.1240: 100%|██████████| 40/40 [00:00<00:00, 48.48it/s]
Epoch 2: train_loss=2.3188, val_loss=2.3018, val_acc=0.1300: 100%|██████████| 40/40 [00:00<00:00, 48.44it/s]
Epoch 3: train_loss=2.3180, val_loss=2.3008, val_acc=0.1390: 100%|██████████| 40/40 [00:00<00:00, 49.05it/s]
Epoch 4: train_loss=2.3041, val_loss=2.2998, val_acc=0.1500: 100%|██████████| 40/40 [00:00<00:00, 48.85it/s]
Epoch 5: train_loss=2.3068, val_loss=2.2989, val_acc=0.1630: 100%|██████████| 40/40 [00:00<00:00, 48.42it/s]
Epoch 6: train_loss=2.2906, val_loss=2.2979, val_acc=0.1720: 100%|██████████| 40/40 [00:00<00:00, 48.40it/s]
Epoch 7: train_loss=2.3232, val_loss=2.2969, val_acc=0.1830: 100%|██████████| 40/40 [00:00<00:00, 48.46it/s]
Epoch 8: train_loss=2.3068, val_loss=2.2959, val_acc=0.1920: 100%|██████████| 40/40 [00:00<00:00, 48.58it/s]
Epoch 9: train_loss

In [6]:
from bio_if.modules.influence import influence

In [22]:
def dataset_to_list(dataset):
    return [(x, torch.tensor(y)) for x, y in dataset]

In [23]:
influence_scores = influence(model, list(model), dataset_to_list(test_data), dataset_to_list(train_data), dataset_to_list([train_data[i] for i in idx]), DEVICE, torch.nn.CrossEntropyLoss(), aggregate_query_grads=True)

Computing EKFAC factors and pseudo gradients


100%|██████████| 5000/5000 [00:20<00:00, 240.30it/s]


Computing search gradients


100%|██████████| 20/20 [00:00<00:00, 328.38it/s]


Computing iHVP


100%|██████████| 1/1 [00:00<00:00, 201.97it/s]
100%|██████████| 1/1 [00:00<00:00, 291.23it/s]
100%|██████████| 1/1 [00:00<00:00, 283.05it/s]
100%|██████████| 1/1 [00:00<00:00, 287.03it/s]
100%|██████████| 1/1 [00:00<00:00, 297.17it/s]
100%|██████████| 1/1 [00:00<00:00, 287.22it/s]
100%|██████████| 1/1 [00:00<00:00, 291.76it/s]
100%|██████████| 1/1 [00:00<00:00, 274.32it/s]
100%|██████████| 1/1 [00:00<00:00, 262.28it/s]
100%|██████████| 1/1 [00:00<00:00, 257.70it/s]
100%|██████████| 1/1 [00:00<00:00, 296.42it/s]
100%|██████████| 1/1 [00:00<00:00, 259.66it/s]
100%|██████████| 1/1 [00:00<00:00, 253.14it/s]
100%|██████████| 1/1 [00:00<00:00, 252.64it/s]
100%|██████████| 1/1 [00:00<00:00, 246.49it/s]
100%|██████████| 1/1 [00:00<00:00, 233.76it/s]
100%|██████████| 1/1 [00:00<00:00, 227.64it/s]
100%|██████████| 1/1 [00:00<00:00, 217.90it/s]
100%|██████████| 1/1 [00:00<00:00, 209.57it/s]
100%|██████████| 1/1 [00:00<00:00, 208.65it/s]
100%|██████████| 1/1 [00:00<00:00, 198.62it/s]
100%|████████

In [24]:
influence_scores

tensor([ 1.5675e-04,  7.5060e-04,  3.0239e-05, -4.9984e-04, -1.5107e-05,
         4.0858e-05,  1.3132e-05,  5.9237e-04,  8.3166e-05, -2.0216e-04,
         2.3649e-05, -4.2508e-05, -4.7127e-06,  7.4614e-05, -1.8262e-04,
         1.4127e-05, -1.1481e-04,  1.7710e-04,  3.5198e-04, -2.8416e-04],
       device='cuda:7')

In [25]:
# save model and influence scores
torch.save(model.state_dict(), 'mnist_model.pth')
torch.save(influence_scores, 'mnist_influence_scores.pth')

In [19]:
model = generate_model()
ckpt = torch.load('mnist_model.pth')   
model.load_state_dict(ckpt)

<All keys matched successfully>

In [20]:
DEVICE = torch.device("cuda:2")
model.to(DEVICE)

Sequential(
  (0): MLPBlock(
    (linear): Linear(in_features=784, out_features=128, bias=True)
    (relu): ReLU()
  )
  (1): MLPBlock(
    (linear): Linear(in_features=128, out_features=128, bias=True)
    (relu): ReLU()
  )
  (2): MLPBlock(
    (linear): Linear(in_features=128, out_features=10, bias=True)
    (relu): ReLU()
  )
)

In [21]:
from bio_if.modules.pbo import proximal_bregman_objective


In [22]:
pbo_test_losses = []

EPOCHS = 100
for i, dataset in enumerate(datasets_removed):
    model_new = generate_model().to(DEVICE)
    model_new.load_state_dict(torch.load('mnist_model.pth'))
    optimizer = torch.optim.Adam(model_new.parameters(), lr=0.001)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
    for epoch in range(EPOCHS // 2):
        model_new.train()
        for x, y in tqdm(train_loader, desc=f'Epoch {epoch} Dataset {i}'):
            x, y = x.to(DEVICE), y.to(DEVICE)
            print(x,y)
            optimizer.zero_grad()
            loss = proximal_bregman_objective(x, y, model_new, model, torch.nn.CrossEntropyLoss(), 0.001)
            loss.backward()
            optimizer.step()
    # calculate test loss
    model_new.eval()
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=False)
    loss = 0
    for x, y in test_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        y_pred = model_new(x)
        loss += nn.functional.cross_entropy(y_pred, y, reduction='sum').item()
    pbo_test_losses.append(loss / len(test_data))
    print(f'Dataset {i}: test_loss={pbo_test_losses[-1]}')


Epoch 0 Dataset 0:   0%|          | 0/20 [00:00<?, ?it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([5, 4, 1, 6, 3, 1, 8, 5, 5, 3, 8, 3, 6, 9, 0, 7, 4, 7, 2, 6, 9, 7, 1, 8,
        6, 8, 1, 0, 7, 2, 0, 6, 5, 6, 9, 4, 0, 1, 6, 6, 0, 3, 0, 6, 7, 4, 0, 6,
        3, 9, 3, 8, 8, 1, 8, 2, 9, 7, 7, 0, 1, 0, 8, 8, 8, 2, 0, 9, 1, 0, 7, 8,
        4, 9, 2, 6, 9, 7, 8, 2, 6, 6, 3, 0, 7, 6, 1, 1, 5, 6, 7, 1, 4, 6, 5, 0,
        3, 4, 6, 7, 1, 1, 5, 5, 3, 4, 9, 0, 6, 0, 2, 1, 7, 6, 5, 7, 8, 7, 0, 4,
        1, 9, 7, 9, 4, 8, 1, 6, 6, 3, 8, 5, 3, 1, 1, 3, 3, 5, 2, 5, 8, 4, 6, 3,
        9, 8, 1, 0, 0, 0, 9, 1, 9, 2, 2, 5, 2, 9, 0, 7, 1, 9, 7, 6, 7, 8, 5, 7,
        9, 9, 7, 3, 7, 8, 1, 2, 2, 2, 7, 4, 8, 1, 3, 0, 6, 5, 1, 5, 3, 9, 9, 0,
        7, 4, 7, 8, 1, 0, 7, 2, 4, 1, 5, 1, 8, 2, 2, 0, 4, 7, 2, 7, 2, 0, 0, 8,
        3

Epoch 0 Dataset 0:  20%|██        | 4/20 [00:00<00:02,  7.16it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([3, 4, 4, 7, 2, 6, 9, 9, 0, 6, 6, 2, 6, 2, 8, 5, 6, 3, 8, 5, 0, 6, 4, 7,
        7, 9, 8, 6, 3, 9, 5, 9, 1, 7, 6, 8, 9, 6, 3, 3, 7, 5, 7, 4, 7, 9, 1, 2,
        1, 0, 0, 3, 1, 1, 5, 3, 7, 8, 8, 9, 9, 6, 4, 4, 7, 6, 5, 7, 1, 3, 8, 0,
        8, 0, 9, 5, 2, 6, 4, 0, 7, 6, 1, 4, 0, 2, 0, 7, 6, 7, 7, 7, 1, 9, 4, 5,
        5, 6, 6, 3, 5, 2, 7, 5, 8, 3, 9, 3, 2, 5, 2, 2, 4, 9, 3, 1, 7, 6, 7, 2,
        6, 9, 8, 6, 0, 7, 4, 1, 4, 6, 1, 8, 6, 8, 4, 1, 3, 5, 4, 6, 3, 2, 1, 2,
        1, 0, 2, 6, 3, 9, 9, 0, 8, 8, 2, 1, 8, 6, 7, 3, 8, 4, 3, 8, 1, 0, 3, 2,
        3, 6, 6, 9, 0, 7, 1, 6, 6, 8, 8, 3, 9, 5, 4, 2, 3, 5, 9, 1, 3, 0, 1, 1,
        4, 0, 2, 0, 8, 4, 9, 6, 2, 2, 6, 4, 4, 0, 7, 3, 0, 3, 7, 4, 1, 7, 6, 6,
        8

Epoch 0 Dataset 0:  50%|█████     | 10/20 [00:00<00:00, 15.22it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([3, 6, 2, 2, 8, 2, 1, 0, 6, 6, 0, 1, 0, 0, 0, 0, 7, 7, 8, 7, 4, 4, 4, 4,
        1, 9, 7, 8, 1, 1, 0, 7, 8, 3, 0, 5, 3, 8, 2, 4, 1, 7, 3, 5, 5, 2, 1, 7,
        6, 2, 0, 5, 5, 3, 2, 1, 9, 1, 6, 7, 9, 9, 9, 3, 6, 8, 7, 4, 6, 8, 2, 2,
        6, 4, 0, 0, 8, 7, 2, 2, 8, 8, 1, 3, 0, 6, 7, 2, 5, 3, 6, 0, 5, 8, 2, 2,
        5, 1, 3, 5, 6, 1, 2, 8, 9, 6, 9, 3, 6, 6, 4, 6, 5, 7, 0, 5, 0, 4, 6, 0,
        9, 4, 1, 3, 4, 4, 3, 7, 1, 8, 8, 7, 7, 9, 1, 7, 8, 3, 2, 9, 5, 5, 8, 1,
        4, 1, 0, 9, 6, 6, 6, 8, 7, 1, 3, 1, 4, 8, 5, 6, 8, 2, 9, 9, 6, 8, 8, 2,
        0, 7, 0, 5, 8, 5, 7, 5, 2, 9, 2, 2, 2, 8, 1, 5, 8, 0, 3, 5, 8, 7, 3, 1,
        3, 0, 7, 9, 7, 6, 9, 9, 7, 9, 8, 6, 5, 6, 0, 4, 3, 6, 6, 1, 7, 0, 4, 6,
        2

Epoch 0 Dataset 0:  80%|████████  | 16/20 [00:01<00:00, 20.93it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([4, 9, 8, 4, 9, 3, 2, 8, 0, 4, 6, 6, 9, 5, 4, 1, 3, 1, 8, 0, 0, 1, 3, 1,
        9, 6, 5, 8, 0, 0, 9, 7, 2, 3, 9, 9, 1, 9, 0, 5, 2, 7, 6, 0, 9, 5, 0, 6,
        6, 0, 5, 3, 8, 0, 2, 4, 1, 8, 8, 0, 2, 9, 8, 5, 1, 7, 1, 7, 4, 8, 6, 2,
        1, 1, 1, 1, 4, 3, 7, 4, 3, 7, 1, 4, 8, 4, 9, 6, 1, 0, 6, 3, 7, 6, 3, 9,
        3, 6, 2, 4, 6, 4, 0, 7, 5, 2, 5, 6, 5, 0, 3, 3, 7, 2, 8, 1, 9, 8, 7, 5,
        0, 8, 5, 4, 4, 9, 7, 9, 6, 6, 7, 3, 9, 6, 5, 3, 4, 6, 8, 3, 6, 5, 9, 9,
        2, 7, 1, 5, 4, 1, 1, 9, 7, 4, 2, 0, 2, 0, 2, 6, 7, 2, 9, 0, 4, 0, 8, 6,
        6, 6, 0, 7, 1, 4, 2, 9, 2, 9, 1, 3, 0, 8, 4, 6, 1, 3, 3, 4, 0, 5, 1, 7,
        4, 0, 8, 2, 5, 8, 9, 4, 9, 8, 0, 1, 9, 8, 0, 4, 9, 0, 7, 8, 7, 1, 8, 2,
        6

Epoch 0 Dataset 0: 100%|██████████| 20/20 [00:01<00:00, 15.75it/s]


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([8, 0, 7, 6, 0, 6, 7, 7, 7, 5, 5, 2, 6, 7, 8, 8, 5, 6, 1, 8, 6, 2, 5, 8,
        7, 3, 9, 5, 0, 1, 3, 4, 6, 1, 7, 2, 4, 3, 4, 8, 4, 4, 9, 9, 7, 0, 7, 0,
        8, 9, 8, 0, 2, 8, 3, 1, 9, 6, 6, 2, 5, 5, 6, 0, 6, 1, 7, 8, 5, 2, 0, 8,
        8, 5, 3, 2, 4, 3, 0, 9, 0, 8, 1, 7, 7, 5, 0, 9, 5, 9, 1, 8, 3, 9, 9, 5,
        0, 7, 9, 9, 7, 9, 4, 1, 7, 0, 6, 9, 9, 1, 3, 5, 1, 0, 9, 1, 8, 0, 2, 2,
        4, 0, 5, 4, 5, 3, 5, 7, 2, 6, 2, 7, 9, 5, 2, 9, 3, 2, 8, 5, 3, 3, 6, 0,
        2, 8, 8, 5, 4, 7, 1, 0, 4, 2, 7, 2, 1, 6, 3, 7, 0, 3, 3, 9, 0, 2, 4, 9,
        8, 8, 9, 9, 6, 0, 3, 8, 3, 3, 9, 5, 4, 0, 8, 0, 4, 9, 3, 1, 9, 7, 0, 2,
        0, 2, 3, 5, 0, 4, 1, 0, 4, 5, 5, 6, 4, 1, 6, 3, 6, 7, 4, 1, 7, 4, 8, 8,
        1

Epoch 1 Dataset 0:  15%|█▌        | 3/20 [00:00<00:00, 29.14it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([2, 0, 1, 3, 7, 8, 8, 6, 9, 7, 3, 5, 1, 9, 5, 8, 7, 5, 0, 0, 7, 6, 2, 9,
        1, 6, 1, 3, 0, 2, 6, 3, 3, 3, 4, 8, 2, 9, 5, 0, 4, 1, 0, 3, 9, 5, 3, 3,
        6, 3, 6, 6, 6, 8, 3, 3, 8, 4, 8, 9, 7, 6, 9, 9, 1, 8, 0, 8, 9, 2, 3, 9,
        0, 0, 9, 8, 8, 3, 4, 1, 0, 3, 2, 5, 2, 1, 4, 2, 7, 7, 7, 9, 1, 9, 3, 0,
        1, 5, 9, 2, 4, 4, 2, 4, 4, 6, 7, 7, 8, 2, 3, 6, 2, 7, 1, 9, 7, 0, 8, 7,
        8, 8, 8, 9, 3, 3, 0, 3, 7, 1, 4, 4, 2, 5, 0, 5, 8, 1, 8, 9, 0, 9, 7, 2,
        2, 8, 8, 0, 2, 0, 1, 1, 7, 9, 7, 8, 3, 0, 5, 5, 1, 0, 4, 7, 6, 8, 2, 1,
        1, 4, 8, 7, 1, 7, 9, 2, 9, 9, 9, 7, 3, 8, 5, 5, 4, 2, 3, 4, 8, 5, 6, 9,
        5, 1, 4, 9, 6, 6, 2, 6, 1, 5, 6, 7, 2, 5, 3, 8, 5, 2, 8, 1, 8, 7, 9, 9,
        6

Epoch 1 Dataset 0:  30%|███       | 6/20 [00:00<00:00, 29.17it/s]

tensor([1, 3, 6, 5, 6, 8, 6, 9, 8, 2, 9, 4, 5, 7, 9, 6, 1, 0, 2, 0, 5, 3, 6, 0,
        0, 1, 0, 2, 7, 9, 2, 3, 0, 9, 7, 3, 9, 6, 7, 0, 7, 2, 2, 8, 5, 9, 8, 4,
        5, 9, 3, 3, 6, 0, 3, 1, 2, 0, 2, 3, 9, 6, 1, 1, 4, 9, 4, 8, 8, 2, 7, 0,
        2, 0, 7, 8, 1, 0, 3, 3, 5, 3, 5, 1, 6, 8, 1, 3, 8, 6, 3, 9, 4, 8, 7, 4,
        1, 8, 6, 2, 0, 8, 6, 7, 1, 7, 1, 1, 1, 6, 0, 4, 7, 4, 3, 7, 2, 2, 2, 8,
        4, 8, 2, 6, 3, 4, 4, 5, 7, 1, 8, 8, 0, 6, 2, 9, 5, 0, 8, 3, 4, 7, 2, 3,
        0, 3, 0, 4, 2, 3, 1, 5, 2, 5, 5, 4, 4, 0, 1, 4, 8, 8, 0, 0, 0, 7, 4, 2,
        0, 5, 8, 0, 2, 7, 1, 0, 8, 2, 6, 5, 1, 9, 1, 1, 7, 2, 0, 6, 6, 1, 9, 2,
        0, 2, 8, 3, 7, 3, 0, 5, 9, 5, 3, 3, 5, 4, 7, 7, 7, 4, 8, 7, 7, 1, 3, 8,
        4, 4, 8, 2, 3, 3, 5, 3, 7, 1, 5, 8, 4, 8, 3, 2, 0, 8, 6, 8, 2, 3, 9, 4,
        2, 9, 4, 2, 7, 1, 2, 1, 0, 0, 7, 7, 5, 2, 2, 5], device='cuda:2')
0.0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
  

Epoch 1 Dataset 0:  45%|████▌     | 9/20 [00:00<00:00, 29.22it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([2, 8, 6, 5, 5, 8, 0, 5, 2, 0, 4, 0, 6, 5, 1, 2, 4, 3, 1, 1, 0, 7, 2, 5,
        7, 0, 6, 8, 9, 8, 8, 8, 4, 7, 7, 9, 6, 0, 2, 9, 3, 4, 7, 6, 0, 9, 6, 0,
        5, 9, 4, 9, 0, 1, 9, 1, 8, 4, 2, 9, 0, 2, 0, 7, 9, 3, 7, 8, 7, 3, 2, 6,
        1, 2, 8, 1, 3, 9, 7, 5, 3, 8, 2, 5, 6, 3, 6, 5, 7, 5, 1, 7, 5, 5, 9, 8,
        0, 4, 4, 1, 8, 4, 8, 4, 4, 7, 1, 0, 7, 0, 4, 9, 1, 1, 5, 9, 7, 5, 0, 0,
        3, 8, 4, 9, 0, 6, 2, 2, 2, 7, 8, 5, 7, 3, 1, 6, 6, 2, 3, 9, 9, 5, 9, 3,
        9, 3, 8, 3, 5, 9, 6, 6, 8, 0, 0, 3, 6, 8, 7, 4, 7, 4, 1, 0, 3, 5, 4, 5,
        2, 4, 5, 9, 3, 7, 0, 2, 5, 3, 1, 8, 5, 5, 0, 7, 6, 9, 5, 1, 4, 7, 6, 0,
        9, 5, 5, 2, 1, 2, 3, 5, 7, 3, 7, 0, 0, 8, 4, 2, 8, 7, 9, 6, 6, 1, 3, 2,
        9

Epoch 1 Dataset 0:  60%|██████    | 12/20 [00:00<00:00, 29.27it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([6, 3, 4, 2, 3, 3, 9, 5, 8, 2, 5, 9, 5, 7, 9, 2, 2, 6, 7, 7, 2, 3, 3, 6,
        8, 6, 1, 7, 1, 2, 4, 2, 1, 1, 2, 9, 6, 0, 5, 7, 3, 3, 7, 1, 9, 6, 3, 5,
        5, 6, 0, 8, 4, 3, 1, 5, 4, 0, 9, 1, 1, 7, 7, 3, 3, 5, 2, 1, 2, 0, 7, 5,
        3, 3, 7, 9, 6, 9, 7, 1, 8, 1, 6, 6, 7, 7, 4, 2, 5, 6, 9, 3, 3, 0, 4, 4,
        1, 4, 0, 8, 2, 5, 0, 4, 6, 8, 5, 7, 2, 3, 1, 2, 3, 6, 2, 4, 1, 2, 7, 8,
        7, 1, 7, 6, 2, 3, 3, 8, 2, 5, 1, 6, 4, 6, 3, 6, 6, 6, 5, 8, 1, 1, 6, 2,
        4, 3, 6, 7, 5, 1, 3, 0, 1, 1, 1, 1, 1, 1, 4, 2, 0, 2, 0, 9, 4, 8, 1, 7,
        8, 3, 6, 8, 3, 6, 4, 5, 3, 9, 6, 2, 0, 5, 1, 4, 3, 1, 7, 9, 1, 1, 8, 9,
        2, 6, 0, 7, 0, 3, 9, 8, 9, 2, 6, 5, 1, 3, 7, 8, 1, 9, 2, 9, 7, 7, 1, 5,
        5

Epoch 1 Dataset 0:  75%|███████▌  | 15/20 [00:00<00:00, 28.87it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([2, 7, 8, 5, 9, 9, 3, 7, 1, 1, 4, 0, 9, 1, 6, 3, 3, 9, 3, 3, 0, 4, 8, 4,
        8, 4, 8, 7, 8, 7, 6, 1, 8, 4, 8, 8, 6, 0, 7, 6, 4, 7, 7, 1, 8, 8, 3, 5,
        9, 6, 1, 4, 5, 4, 3, 5, 0, 1, 0, 6, 9, 2, 4, 0, 5, 5, 9, 9, 5, 2, 8, 5,
        7, 3, 7, 9, 5, 9, 4, 9, 0, 3, 3, 5, 7, 9, 3, 9, 6, 0, 4, 7, 1, 2, 9, 8,
        2, 8, 5, 9, 6, 7, 9, 6, 2, 1, 7, 1, 9, 1, 1, 3, 7, 0, 6, 0, 6, 4, 6, 2,
        4, 2, 5, 8, 0, 2, 2, 3, 5, 0, 7, 9, 7, 8, 8, 6, 6, 9, 6, 7, 5, 6, 0, 3,
        7, 8, 5, 1, 0, 8, 1, 1, 9, 5, 3, 4, 6, 4, 9, 5, 7, 4, 5, 0, 2, 4, 6, 6,
        3, 2, 3, 3, 7, 5, 7, 8, 3, 1, 8, 6, 4, 3, 9, 0, 5, 9, 8, 8, 0, 1, 1, 3,
        3, 7, 2, 4, 5, 6, 6, 6, 3, 1, 3, 3, 8, 3, 2, 7, 0, 6, 3, 2, 4, 2, 5, 4,
        4

Epoch 1 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 28.00it/s]


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([4, 8, 1, 4, 3, 1, 8, 7, 5, 1, 2, 1, 4, 5, 2, 9, 8, 1, 4, 4, 6, 6, 2, 1,
        6, 4, 6, 0, 0, 5, 3, 2, 8, 8, 6, 9, 5, 9, 2, 7, 5, 3, 7, 8, 2, 5, 9, 9,
        2, 3, 4, 8, 5, 0, 7, 4, 8, 0, 0, 8, 4, 9, 5, 2, 0, 1, 7, 0, 0, 1, 9, 7,
        0, 2, 1, 9, 9, 7, 7, 9, 8, 3, 6, 0, 8, 2, 5, 5, 2, 6, 7, 9, 7, 8, 4, 7,
        0, 9, 5, 3, 5, 5, 5, 5, 2, 5, 9, 4, 0, 5, 8, 0, 4, 7, 3, 0, 9, 2, 0, 8,
        6, 9, 3, 6, 3, 9, 8, 4, 8, 6, 1, 6, 3, 3, 7, 7, 8, 6, 4, 5, 7, 6, 5, 6,
        2, 4, 4, 4, 9, 9, 1, 5, 4, 3, 3, 6, 8, 9, 0, 8, 6, 4, 3, 1, 9, 3, 8, 3,
        6, 2, 7, 5, 5, 0, 8, 7, 9, 7, 7, 9, 3, 1, 2, 6, 0, 6, 6, 9, 4, 3, 7, 6,
        4, 3, 5, 8, 0, 6, 5, 8, 7, 9, 7, 7, 3, 5, 7, 4, 6, 2, 1, 5, 7, 9, 9, 0,
        1

Epoch 2 Dataset 0:   0%|          | 0/20 [00:00<?, ?it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([3, 3, 2, 0, 4, 3, 7, 0, 3, 6, 6, 8, 7, 9, 0, 1, 9, 4, 2, 7, 7, 0, 9, 4,
        3, 1, 8, 2, 6, 5, 3, 6, 1, 0, 4, 5, 7, 7, 2, 5, 1, 1, 5, 1, 9, 9, 1, 6,
        2, 6, 1, 6, 9, 3, 3, 4, 8, 7, 6, 1, 7, 7, 1, 9, 8, 4, 1, 2, 9, 7, 9, 2,
        9, 5, 2, 1, 1, 6, 9, 6, 1, 6, 3, 0, 4, 2, 1, 7, 7, 7, 8, 0, 3, 8, 6, 3,
        5, 5, 3, 7, 9, 1, 8, 4, 4, 5, 1, 6, 9, 7, 3, 3, 0, 2, 9, 6, 9, 1, 9, 2,
        3, 8, 4, 1, 5, 8, 6, 3, 5, 2, 6, 0, 5, 4, 0, 2, 0, 8, 7, 9, 6, 7, 3, 1,
        4, 4, 7, 9, 1, 8, 1, 3, 5, 3, 8, 0, 3, 7, 9, 2, 4, 8, 0, 2, 2, 2, 4, 5,
        6, 0, 3, 6, 2, 3, 9, 4, 5, 5, 7, 8, 8, 2, 0, 2, 6, 6, 3, 5, 2, 0, 4, 5,
        5, 3, 9, 4, 3, 0, 9, 5, 9, 6, 2, 8, 1, 8, 3, 8, 3, 3, 3, 0, 3, 8, 8, 8,
        0

Epoch 2 Dataset 0:  15%|█▌        | 3/20 [00:00<00:00, 27.52it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([8, 2, 5, 7, 0, 4, 7, 8, 7, 3, 6, 0, 7, 3, 5, 7, 6, 4, 9, 3, 7, 2, 4, 8,
        1, 2, 2, 5, 3, 2, 4, 5, 9, 4, 1, 9, 7, 0, 7, 3, 3, 4, 4, 8, 7, 8, 0, 9,
        8, 2, 6, 4, 7, 7, 0, 5, 3, 0, 8, 5, 1, 6, 3, 2, 1, 1, 1, 9, 2, 6, 6, 1,
        3, 0, 9, 8, 5, 0, 5, 3, 5, 1, 8, 9, 4, 1, 5, 7, 1, 7, 8, 5, 3, 7, 3, 0,
        6, 8, 0, 4, 4, 1, 0, 9, 0, 7, 0, 4, 6, 6, 7, 0, 9, 2, 8, 3, 6, 8, 9, 4,
        2, 2, 9, 8, 5, 1, 0, 7, 7, 5, 9, 1, 1, 4, 9, 0, 7, 3, 7, 9, 0, 1, 1, 1,
        3, 2, 4, 9, 3, 6, 2, 8, 9, 3, 4, 2, 2, 8, 0, 0, 7, 1, 8, 5, 1, 9, 0, 2,
        5, 2, 4, 4, 9, 4, 8, 3, 8, 7, 5, 8, 9, 5, 5, 7, 5, 3, 7, 6, 1, 4, 3, 2,
        6, 9, 1, 9, 7, 6, 8, 5, 9, 0, 0, 6, 1, 5, 6, 2, 3, 5, 9, 9, 8, 5, 0, 8,
        1

Epoch 2 Dataset 0:  30%|███       | 6/20 [00:00<00:00, 27.85it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([4, 0, 2, 0, 0, 5, 0, 5, 1, 5, 0, 7, 1, 3, 4, 6, 1, 0, 7, 2, 8, 4, 9, 4,
        0, 8, 2, 8, 0, 1, 1, 3, 9, 6, 9, 5, 7, 4, 8, 9, 5, 2, 4, 9, 9, 6, 8, 2,
        0, 5, 6, 9, 8, 7, 8, 9, 1, 9, 7, 6, 9, 3, 4, 9, 7, 2, 6, 7, 5, 8, 3, 4,
        7, 5, 1, 2, 5, 2, 4, 8, 9, 5, 2, 2, 2, 8, 2, 8, 9, 2, 3, 8, 4, 7, 6, 0,
        7, 6, 2, 1, 1, 9, 8, 3, 5, 3, 5, 2, 9, 1, 4, 0, 1, 4, 3, 1, 6, 2, 0, 4,
        7, 8, 7, 1, 0, 0, 7, 3, 9, 1, 8, 3, 6, 6, 1, 1, 3, 0, 9, 5, 8, 0, 7, 0,
        0, 4, 5, 3, 8, 1, 9, 7, 8, 3, 2, 0, 0, 7, 5, 0, 6, 2, 8, 3, 4, 5, 5, 5,
        5, 0, 9, 7, 6, 7, 8, 0, 0, 0, 7, 5, 5, 4, 9, 8, 6, 1, 9, 7, 3, 4, 8, 1,
        0, 4, 4, 0, 8, 1, 2, 3, 7, 3, 7, 4, 2, 8, 0, 8, 6, 4, 6, 2, 1, 3, 1, 5,
        4

Epoch 2 Dataset 0:  45%|████▌     | 9/20 [00:00<00:00, 27.88it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([9, 5, 5, 8, 7, 7, 1, 2, 0, 9, 0, 3, 6, 0, 9, 9, 0, 5, 7, 6, 6, 4, 7, 4,
        9, 6, 4, 8, 2, 5, 1, 1, 7, 9, 9, 9, 1, 8, 5, 3, 0, 0, 4, 8, 3, 0, 2, 1,
        6, 9, 3, 7, 2, 8, 0, 3, 7, 7, 1, 1, 8, 8, 0, 1, 1, 2, 8, 5, 5, 4, 1, 8,
        5, 4, 3, 9, 7, 2, 5, 9, 7, 7, 8, 0, 3, 2, 8, 7, 1, 9, 4, 1, 1, 8, 0, 2,
        8, 1, 9, 4, 3, 0, 9, 5, 4, 1, 3, 8, 4, 6, 6, 7, 2, 0, 6, 0, 6, 3, 0, 6,
        3, 0, 6, 3, 8, 0, 0, 1, 0, 8, 8, 7, 7, 7, 7, 2, 2, 1, 7, 2, 7, 8, 2, 5,
        0, 3, 7, 6, 8, 3, 3, 4, 8, 9, 4, 9, 2, 0, 3, 3, 8, 8, 4, 3, 4, 3, 9, 7,
        3, 5, 3, 8, 5, 5, 9, 1, 2, 2, 2, 6, 5, 3, 0, 7, 5, 1, 0, 7, 6, 0, 9, 0,
        6, 0, 9, 8, 8, 8, 1, 6, 3, 1, 5, 5, 8, 1, 8, 2, 2, 4, 2, 7, 4, 0, 8, 9,
        0

Epoch 2 Dataset 0:  60%|██████    | 12/20 [00:00<00:00, 28.05it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([7, 9, 2, 1, 6, 2, 5, 6, 5, 9, 4, 6, 2, 1, 7, 0, 6, 2, 2, 2, 7, 9, 6, 3,
        0, 1, 4, 1, 8, 5, 1, 1, 9, 5, 3, 9, 2, 6, 1, 9, 6, 9, 7, 8, 6, 6, 0, 4,
        1, 5, 8, 6, 7, 7, 6, 3, 0, 4, 3, 9, 2, 7, 5, 8, 7, 3, 2, 3, 9, 8, 0, 1,
        5, 5, 1, 0, 9, 1, 7, 4, 5, 2, 0, 5, 6, 3, 7, 3, 5, 4, 5, 0, 0, 4, 8, 2,
        3, 0, 5, 2, 2, 1, 5, 9, 2, 8, 1, 9, 3, 6, 1, 2, 7, 4, 3, 1, 2, 7, 3, 3,
        5, 5, 0, 8, 4, 4, 9, 9, 4, 9, 4, 0, 6, 6, 9, 8, 1, 7, 3, 5, 0, 3, 0, 2,
        2, 1, 5, 6, 3, 8, 3, 4, 7, 6, 4, 4, 0, 2, 9, 9, 7, 2, 6, 8, 7, 4, 7, 8,
        9, 8, 2, 1, 7, 0, 7, 8, 7, 3, 6, 7, 7, 8, 9, 0, 5, 6, 6, 6, 3, 2, 7, 1,
        4, 5, 9, 2, 2, 9, 2, 5, 8, 6, 5, 6, 1, 9, 0, 0, 1, 7, 3, 8, 6, 0, 1, 9,
        7

Epoch 2 Dataset 0:  75%|███████▌  | 15/20 [00:00<00:00, 27.97it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([8, 2, 4, 8, 6, 4, 5, 1, 0, 1, 7, 7, 3, 8, 8, 6, 2, 6, 8, 1, 8, 0, 5, 0,
        3, 8, 4, 3, 5, 4, 9, 3, 1, 8, 5, 9, 4, 0, 6, 4, 9, 6, 5, 1, 7, 1, 7, 6,
        1, 7, 1, 0, 6, 2, 7, 9, 9, 7, 4, 9, 7, 5, 7, 3, 0, 8, 4, 7, 5, 5, 0, 6,
        4, 0, 6, 7, 9, 2, 4, 6, 3, 4, 4, 7, 9, 1, 9, 7, 0, 1, 5, 0, 6, 8, 4, 9,
        7, 7, 3, 5, 2, 4, 6, 6, 5, 3, 3, 6, 9, 7, 8, 7, 8, 0, 5, 2, 9, 1, 7, 4,
        5, 4, 8, 9, 5, 7, 0, 0, 6, 8, 0, 8, 4, 8, 5, 5, 4, 4, 8, 2, 1, 4, 3, 3,
        5, 5, 0, 0, 2, 4, 0, 2, 2, 6, 9, 3, 1, 5, 7, 1, 8, 9, 5, 0, 0, 3, 7, 9,
        1, 6, 0, 2, 2, 6, 7, 1, 1, 8, 9, 6, 4, 3, 3, 8, 5, 4, 2, 9, 2, 6, 8, 2,
        2, 7, 1, 1, 4, 0, 7, 6, 7, 2, 3, 0, 7, 1, 3, 5, 1, 0, 7, 9, 8, 3, 5, 9,
        5

Epoch 2 Dataset 0:  90%|█████████ | 18/20 [00:00<00:00, 21.18it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([8, 0, 4, 8, 5, 9, 4, 7, 5, 4, 2, 2, 6, 2, 9, 6, 3, 3, 8, 6, 6, 6, 5, 1,
        0, 6, 7, 7, 0, 2, 2, 4, 7, 3, 1, 0, 6, 3, 0, 3, 9, 9, 5, 0, 8, 6, 8, 7,
        0, 4, 9, 9, 4, 1, 5, 6, 2, 2, 5, 1, 0, 8, 9, 2, 0, 1, 0, 0, 8, 3, 5, 5,
        1, 5, 7, 0, 2, 9, 5, 2, 9, 9, 3, 5, 6, 1, 7, 2, 0, 8, 1, 4, 6, 0, 4, 5,
        2, 7, 7, 9, 1, 5, 8, 5, 7, 2, 9, 4, 0, 8, 6, 2, 3, 8, 7, 1, 8, 3, 3, 2,
        5, 2, 0, 7, 0, 7, 8, 1, 0, 6, 3, 5, 5, 5, 5, 7, 0, 3, 9, 9, 1, 6, 8, 4,
        4, 3, 1, 5, 7, 2, 3, 4, 2, 7, 9, 0, 9, 2, 2, 8, 6, 7, 7, 7, 0, 4, 5, 0,
        0, 9, 7, 5, 2, 4, 1, 0, 3, 9, 6, 4, 0, 0, 4, 4, 6, 1, 2, 7, 2, 2, 3, 8,
        8, 1, 9, 2, 5, 9, 1, 1, 7, 5, 6, 7, 8, 9, 5, 0, 4, 4, 8, 7, 6, 7, 2, 1,
        0

Epoch 2 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 24.88it/s]


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([3, 6, 9, 7, 7, 2, 0, 9, 4, 9, 6, 3, 9, 1, 0, 4, 0, 9, 9, 9, 1, 6, 2, 4,
        1, 1, 4, 1, 6, 5, 2, 0, 8, 4, 0, 7, 9, 2, 8, 1, 5, 2, 3, 3, 8, 3, 9, 4,
        8, 1, 4, 0, 8, 1, 3, 9, 7, 1, 7, 6, 8, 9, 0, 3, 3, 0, 6, 8, 9, 9, 8, 4,
        2, 1, 1, 5, 0, 9, 4, 7, 4, 3, 0, 6, 8, 4, 8, 6, 7, 5, 1, 8, 0, 1, 7, 0,
        5, 2, 9, 0, 6, 1, 2, 0, 8, 1, 7, 0, 3, 0, 6, 6, 6, 4, 2, 5, 2, 7, 2, 8,
        4, 1, 8, 1, 7, 6, 8, 4, 5, 5, 0, 1, 4, 2, 3], device='cuda:2')
0.0


Epoch 3 Dataset 0:   0%|          | 0/20 [00:00<?, ?it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([6, 8, 8, 0, 4, 3, 9, 4, 5, 3, 6, 5, 1, 0, 0, 3, 7, 8, 6, 8, 4, 5, 7, 3,
        7, 0, 0, 3, 3, 2, 5, 8, 9, 7, 7, 8, 5, 9, 8, 6, 8, 8, 4, 4, 0, 8, 4, 8,
        4, 9, 1, 1, 5, 2, 2, 1, 7, 6, 9, 3, 4, 8, 6, 3, 7, 1, 6, 5, 5, 8, 0, 4,
        9, 5, 2, 3, 3, 9, 1, 0, 7, 4, 8, 6, 6, 9, 8, 9, 7, 8, 9, 9, 2, 3, 8, 7,
        2, 7, 3, 3, 3, 0, 3, 5, 7, 5, 0, 7, 7, 4, 4, 2, 9, 8, 2, 7, 1, 4, 2, 0,
        7, 9, 0, 4, 6, 2, 1, 8, 2, 2, 1, 2, 0, 8, 4, 3, 5, 0, 9, 6, 0, 0, 6, 7,
        0, 7, 2, 4, 7, 7, 3, 2, 7, 9, 0, 2, 6, 1, 1, 4, 7, 7, 1, 5, 5, 3, 6, 4,
        6, 7, 5, 6, 3, 5, 5, 1, 3, 3, 9, 8, 4, 4, 3, 7, 7, 6, 0, 1, 1, 1, 3, 5,
        2, 7, 6, 7, 4, 0, 3, 0, 2, 6, 8, 2, 5, 2, 1, 6, 0, 0, 3, 5, 8, 8, 0, 3,
        8

Epoch 3 Dataset 0:  15%|█▌        | 3/20 [00:00<00:00, 28.26it/s]

0.0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([3, 1, 0, 5, 6, 5, 5, 5, 8, 0, 6, 6, 8, 9, 2, 2, 4, 1, 9, 0, 8, 7, 5, 4,
        4, 6, 6, 8, 7, 6, 0, 0, 7, 6, 1, 3, 0, 3, 0, 6, 6, 0, 4, 7, 8, 9, 6, 6,
        7, 8, 9, 7, 8, 2, 1, 8, 5, 8, 1, 8, 2, 7, 5, 9, 6, 6, 1, 0, 8, 5, 1, 2,
        1, 7, 1, 8, 8, 2, 1, 6, 9, 4, 6, 4, 8, 4, 2, 9, 9, 4, 7, 5, 0, 9, 3, 9,
        7, 7, 6, 6, 1, 1, 7, 7, 6, 0, 8, 6, 2, 8, 3, 8, 9, 1, 3, 5, 1, 5, 3, 8,
        2, 6, 3, 9, 1, 3, 0, 1, 6, 1, 5, 4, 3, 9, 0, 5, 9, 8, 4, 1, 3, 0, 4, 8,
        8, 0, 2, 0, 6, 5, 8, 2, 7, 0, 9, 8, 7, 2, 0, 0, 4, 3, 1, 6, 7, 9, 6, 4,
        1, 0, 3, 3, 3, 5, 7, 0, 7, 0, 4, 3, 1, 6, 2, 2, 1, 8, 1, 2, 6, 6, 1, 6,
        2, 2, 8, 4, 8, 5, 6, 6, 0, 3, 8, 5, 7, 3, 6, 7, 4, 1, 4, 2, 4, 8, 2, 7,
     

Epoch 3 Dataset 0:  35%|███▌      | 7/20 [00:00<00:00, 30.46it/s]

0.0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([4, 8, 4, 0, 0, 6, 4, 8, 0, 2, 4, 0, 7, 9, 4, 6, 8, 4, 4, 7, 6, 0, 3, 1,
        4, 1, 8, 3, 8, 9, 0, 0, 2, 6, 2, 5, 6, 9, 6, 4, 8, 6, 2, 5, 1, 5, 0, 8,
        6, 9, 1, 8, 8, 7, 2, 7, 7, 6, 2, 8, 2, 2, 4, 0, 4, 2, 5, 4, 0, 6, 6, 7,
        9, 1, 0, 2, 5, 9, 0, 0, 6, 4, 8, 4, 7, 0, 7, 7, 7, 8, 8, 5, 6, 1, 8, 9,
        9, 9, 2, 2, 3, 2, 6, 8, 8, 1, 8, 6, 2, 5, 3, 1, 1, 5, 3, 1, 0, 4, 2, 8,
        6, 5, 4, 2, 5, 3, 5, 4, 0, 1, 9, 3, 8, 7, 9, 6, 6, 9, 5, 4, 0, 7, 1, 7,
        6, 6, 5, 6, 0, 1, 8, 0, 4, 2, 4, 4, 6, 3, 9, 9, 4, 2, 4, 2, 0, 2, 8, 6,
        1, 3, 6, 2, 5, 9, 0, 5, 3, 5, 5, 6, 0, 8, 0, 6, 6, 3, 7, 5, 8, 1, 2, 1,
        7, 5, 5, 5, 8, 1, 2, 9, 6, 5, 7, 4, 6, 6, 6, 5, 7, 3, 3, 0, 2, 6, 7, 9,
     

Epoch 3 Dataset 0:  55%|█████▌    | 11/20 [00:00<00:00, 31.20it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([1, 7, 1, 0, 6, 3, 3, 5, 6, 9, 3, 5, 0, 6, 6, 3, 4, 0, 5, 0, 0, 6, 3, 9,
        4, 7, 9, 1, 5, 1, 4, 4, 1, 5, 0, 9, 6, 8, 1, 6, 6, 4, 8, 5, 8, 8, 7, 6,
        4, 3, 6, 8, 1, 6, 8, 0, 3, 1, 8, 4, 6, 0, 5, 3, 2, 5, 9, 7, 1, 9, 3, 0,
        7, 5, 6, 6, 2, 0, 6, 0, 1, 4, 2, 1, 0, 0, 2, 7, 6, 7, 0, 0, 2, 9, 1, 6,
        0, 1, 9, 7, 8, 0, 3, 3, 7, 7, 5, 2, 0, 0, 4, 9, 0, 5, 3, 3, 3, 7, 8, 9,
        9, 0, 9, 8, 8, 8, 3, 7, 5, 5, 9, 3, 5, 3, 1, 1, 7, 5, 6, 1, 9, 3, 3, 1,
        8, 4, 8, 8, 6, 7, 1, 3, 1, 4, 9, 9, 2, 5, 1, 0, 3, 3, 7, 2, 9, 8, 0, 1,
        9, 0, 0, 1, 2, 1, 6, 0, 5, 3, 0, 7, 9, 8, 0, 8, 4, 3, 5, 7, 6, 9, 9, 5,
        8, 7, 0, 0, 7, 6, 2, 8, 3, 3, 9, 7, 8, 8, 6, 4, 9, 5, 2, 3, 3, 9, 2, 3,
        9

Epoch 3 Dataset 0:  75%|███████▌  | 15/20 [00:00<00:00, 31.55it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([7, 9, 5, 8, 4, 0, 4, 7, 8, 8, 3, 7, 9, 3, 2, 2, 1, 8, 4, 6, 9, 5, 9, 7,
        4, 0, 3, 1, 0, 7, 5, 7, 1, 8, 0, 9, 6, 0, 6, 0, 0, 4, 0, 1, 8, 7, 3, 5,
        4, 1, 3, 0, 9, 2, 7, 8, 0, 1, 8, 9, 6, 7, 7, 6, 2, 7, 2, 4, 7, 9, 1, 7,
        3, 7, 3, 2, 6, 2, 7, 6, 4, 9, 4, 8, 1, 6, 1, 8, 6, 3, 1, 2, 9, 8, 1, 3,
        9, 5, 9, 7, 8, 2, 1, 8, 6, 2, 3, 9, 3, 4, 5, 1, 7, 6, 4, 9, 2, 7, 7, 9,
        7, 6, 3, 0, 9, 1, 8, 6, 8, 9, 8, 6, 5, 4, 4, 3, 8, 8, 4, 0, 4, 4, 9, 8,
        7, 1, 2, 6, 0, 7, 6, 5, 5, 4, 6, 2, 2, 3, 4, 9, 2, 4, 0, 4, 9, 8, 2, 4,
        9, 3, 4, 1, 8, 0, 7, 5, 0, 4, 8, 2, 7, 6, 1, 6, 5, 1, 7, 9, 6, 8, 2, 5,
        1, 1, 3, 9, 9, 2, 1, 3, 6, 4, 8, 8, 7, 0, 1, 1, 6, 7, 0, 0, 6, 9, 2, 2,
        3

Epoch 3 Dataset 0:  95%|█████████▌| 19/20 [00:00<00:00, 31.70it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([0, 2, 6, 1, 6, 1, 9, 3, 6, 5, 3, 5, 6, 5, 7, 3, 9, 0, 5, 8, 2, 2, 2, 8,
        3, 0, 3, 4, 1, 1, 8, 0, 5, 1, 2, 2, 7, 3, 5, 3, 2, 8, 2, 6, 3, 5, 1, 7,
        1, 3, 7, 8, 9, 3, 9, 9, 1, 8, 1, 4, 3, 7, 1, 4, 0, 9, 5, 4, 0, 7, 1, 7,
        2, 6, 1, 9, 7, 3, 5, 7, 1, 5, 3, 9, 5, 3, 4, 2, 4, 8, 3, 7, 2, 2, 8, 9,
        9, 1, 8, 6, 3, 3, 7, 0, 5, 8, 7, 4, 3, 4, 1, 5, 8, 1, 7, 3, 5, 0, 6, 1,
        0, 7, 1, 1, 7, 0, 8, 1, 0, 9, 1, 0, 4, 5, 5, 6, 6, 8, 4, 4, 9, 4, 0, 1,
        1, 8, 4, 7, 7, 6, 8, 3, 4, 3, 5, 4, 9, 3, 8, 0, 1, 7, 1, 3, 2, 0, 2, 2,
        4, 2, 3, 7, 1, 3, 8, 1, 8, 0, 0, 2, 4, 9, 5, 9, 7, 8, 5, 5, 7, 6, 6, 2,
        3, 5, 8, 0, 9, 2, 3, 0, 7, 7, 3, 2, 5, 1, 4, 8, 5, 6, 8, 4, 1, 1, 2, 0,
        6

Epoch 3 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 31.81it/s]


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([3, 2, 8, 6, 0, 5, 0, 7, 9, 0, 3, 3, 7, 0, 7, 6, 7, 9, 2, 7, 3, 8, 5, 9,
        8, 1, 4, 2, 2, 1, 1, 9, 9, 9, 1, 9, 6, 1, 4, 7, 7, 6, 2, 4, 6, 2, 7, 5,
        2, 3, 8, 7, 5, 5, 4, 3, 9, 1, 7, 9, 6, 6, 5, 0, 9, 8, 2, 0, 4, 9, 2, 5,
        6, 7, 1, 9, 3, 8, 9, 9, 4, 6, 1, 7, 3, 5, 5, 4, 7, 6, 0, 1, 3, 4, 7, 8,
        2, 6, 4, 9, 3, 4, 2, 5, 8, 2, 0, 4, 7, 5, 8, 6, 3, 0, 0, 7, 6, 3, 8, 8,
        8, 1, 9, 7, 3, 5, 1, 0, 0, 7, 8, 3, 7, 5, 5], device='cuda:2')
0.0


Epoch 4 Dataset 0:   0%|          | 0/20 [00:00<?, ?it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([2, 3, 0, 3, 9, 1, 5, 4, 5, 3, 9, 9, 7, 8, 4, 0, 0, 0, 3, 3, 9, 9, 6, 6,
        8, 2, 5, 8, 1, 8, 3, 7, 5, 1, 6, 8, 8, 5, 6, 3, 7, 8, 3, 9, 9, 0, 1, 7,
        4, 8, 4, 1, 0, 4, 1, 0, 1, 9, 4, 9, 9, 0, 8, 7, 2, 0, 9, 7, 3, 0, 4, 0,
        5, 2, 0, 0, 3, 5, 2, 6, 1, 6, 7, 2, 3, 3, 7, 0, 0, 1, 6, 0, 0, 0, 2, 2,
        6, 7, 9, 2, 8, 9, 8, 1, 9, 8, 9, 2, 3, 8, 3, 1, 2, 9, 1, 1, 1, 4, 1, 7,
        4, 3, 1, 5, 2, 9, 8, 7, 5, 3, 9, 9, 2, 7, 1, 1, 3, 0, 2, 4, 5, 5, 0, 8,
        6, 3, 3, 7, 7, 5, 3, 2, 8, 7, 6, 9, 8, 1, 8, 6, 2, 1, 0, 0, 1, 1, 5, 2,
        8, 2, 0, 9, 8, 1, 6, 1, 3, 4, 4, 7, 0, 5, 5, 6, 6, 5, 4, 7, 0, 3, 2, 1,
        8, 7, 6, 0, 3, 7, 2, 1, 6, 7, 1, 6, 0, 4, 0, 4, 5, 3, 5, 3, 7, 7, 0, 0,
        8

Epoch 4 Dataset 0:  20%|██        | 4/20 [00:00<00:00, 30.76it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([6, 0, 3, 9, 9, 9, 9, 8, 2, 3, 9, 9, 5, 9, 9, 5, 3, 9, 8, 3, 8, 6, 9, 7,
        5, 3, 9, 7, 0, 3, 2, 8, 0, 4, 8, 9, 3, 1, 2, 8, 9, 0, 4, 5, 6, 1, 0, 2,
        4, 7, 8, 8, 0, 0, 3, 5, 4, 3, 3, 9, 6, 9, 8, 4, 2, 1, 6, 4, 5, 2, 7, 5,
        0, 0, 9, 1, 2, 8, 9, 7, 1, 0, 5, 6, 5, 5, 6, 5, 9, 2, 6, 2, 5, 8, 0, 1,
        9, 9, 1, 0, 7, 0, 6, 6, 3, 9, 6, 2, 9, 8, 0, 7, 7, 5, 6, 3, 7, 3, 5, 3,
        6, 2, 3, 5, 3, 3, 9, 9, 2, 6, 4, 5, 3, 1, 2, 5, 4, 2, 3, 1, 1, 1, 4, 7,
        4, 7, 8, 0, 6, 0, 8, 2, 4, 8, 2, 5, 1, 1, 0, 5, 4, 6, 7, 1, 1, 3, 0, 9,
        2, 1, 5, 3, 4, 5, 6, 2, 2, 2, 2, 6, 5, 2, 5, 4, 3, 4, 6, 4, 2, 3, 3, 8,
        3, 8, 8, 5, 2, 3, 6, 7, 2, 6, 3, 8, 5, 5, 1, 2, 7, 3, 9, 9, 6, 0, 9, 7,
        8

Epoch 4 Dataset 0:  40%|████      | 8/20 [00:00<00:00, 30.14it/s]

tensor([6, 4, 9, 2, 9, 7, 6, 4, 7, 4, 4, 2, 6, 5, 1, 1, 4, 3, 9, 8, 4, 3, 3, 5,
        1, 2, 0, 8, 9, 8, 3, 7, 1, 8, 6, 3, 4, 9, 8, 1, 5, 8, 0, 6, 7, 9, 3, 6,
        5, 5, 8, 0, 9, 7, 7, 7, 8, 6, 1, 9, 2, 9, 2, 8, 0, 8, 8, 7, 4, 1, 8, 0,
        8, 1, 7, 3, 0, 0, 8, 8, 6, 9, 4, 6, 8, 5, 9, 0, 1, 1, 0, 4, 9, 1, 9, 1,
        1, 3, 0, 1, 1, 5, 4, 7, 5, 4, 3, 5, 9, 2, 1, 5, 6, 6, 6, 9, 4, 5, 6, 0,
        5, 1, 9, 3, 6, 0, 6, 2, 8, 4, 3, 2, 3, 2, 4, 1, 1, 7, 9, 6, 0, 4, 8, 3,
        1, 8, 9, 0, 2, 7, 7, 2, 2, 2, 9, 6, 3, 9, 7, 8, 6, 0, 3, 7, 6, 2, 3, 2,
        6, 7, 8, 3, 0, 8, 9, 6, 9, 9, 5, 8, 3, 0, 3, 5, 4, 9, 7, 8, 1, 2, 6, 9,
        7, 0, 0, 8, 6, 7, 7, 3, 3, 8, 4, 7, 0, 7, 8, 6, 3, 8, 9, 6, 3, 1, 7, 7,
        3, 0, 3, 2, 5, 2, 4, 8, 3, 9, 2, 9, 8, 6, 3, 6, 5, 4, 4, 0, 2, 1, 1, 1,
        2, 0, 0, 2, 3, 6, 0, 0, 5, 9, 7, 4, 4, 8, 8, 7], device='cuda:2')
0.0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
  

Epoch 4 Dataset 0:  60%|██████    | 12/20 [00:00<00:00, 30.99it/s]

0.0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([4, 4, 0, 6, 6, 2, 6, 4, 6, 2, 9, 4, 9, 1, 6, 0, 6, 4, 9, 7, 0, 4, 5, 4,
        3, 5, 1, 8, 5, 3, 8, 6, 8, 0, 7, 1, 4, 6, 8, 5, 1, 2, 7, 4, 2, 8, 1, 6,
        5, 2, 4, 6, 6, 4, 4, 7, 9, 4, 0, 9, 4, 1, 2, 0, 5, 0, 3, 5, 0, 8, 1, 3,
        6, 7, 4, 0, 1, 6, 4, 5, 2, 7, 7, 4, 0, 1, 1, 2, 2, 8, 0, 6, 8, 4, 3, 5,
        2, 8, 2, 9, 2, 2, 2, 9, 7, 3, 1, 9, 8, 3, 7, 8, 7, 4, 1, 0, 4, 8, 0, 4,
        8, 3, 7, 7, 9, 0, 2, 5, 0, 2, 9, 4, 2, 6, 1, 0, 3, 0, 9, 5, 7, 5, 0, 1,
        8, 2, 3, 8, 4, 6, 0, 6, 8, 3, 7, 5, 9, 8, 1, 9, 1, 3, 3, 2, 8, 8, 0, 0,
        7, 1, 7, 3, 4, 0, 3, 6, 1, 0, 0, 5, 1, 3, 1, 6, 5, 3, 5, 1, 8, 6, 2, 5,
        3, 9, 9, 4, 3, 7, 7, 0, 5, 7, 2, 6, 4, 7, 1, 0, 9, 7, 3, 1, 8, 9, 4, 0,
     

Epoch 4 Dataset 0:  80%|████████  | 16/20 [00:00<00:00, 31.38it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([8, 7, 1, 9, 2, 9, 2, 8, 8, 6, 8, 9, 3, 9, 0, 9, 6, 3, 0, 2, 4, 4, 0, 9,
        0, 5, 5, 2, 0, 9, 3, 3, 4, 4, 7, 0, 5, 9, 8, 0, 1, 8, 1, 2, 3, 5, 4, 3,
        3, 0, 4, 4, 6, 0, 9, 8, 3, 1, 9, 4, 6, 0, 4, 4, 4, 7, 6, 1, 4, 5, 2, 5,
        6, 6, 3, 7, 6, 8, 2, 4, 0, 6, 1, 1, 5, 4, 9, 6, 6, 5, 8, 8, 0, 4, 1, 6,
        5, 3, 5, 7, 5, 8, 8, 2, 1, 4, 4, 2, 1, 9, 0, 8, 9, 1, 3, 9, 6, 5, 8, 8,
        0, 1, 7, 3, 6, 7, 9, 1, 6, 2, 2, 5, 1, 3, 4, 6, 0, 7, 3, 9, 2, 5, 2, 2,
        3, 3, 9, 0, 6, 6, 4, 3, 7, 6, 6, 1, 7, 8, 0, 7, 1, 4, 1, 7, 2, 6, 2, 1,
        1, 9, 1, 6, 7, 0, 4, 9, 5, 4, 2, 5, 7, 2, 2, 6, 7, 8, 8, 1, 7, 1, 9, 1,
        4, 3, 5, 9, 6, 3, 5, 5, 8, 5, 1, 2, 3, 1, 8, 4, 6, 6, 4, 4, 1, 1, 5, 2,
        7

Epoch 4 Dataset 0: 100%|██████████| 20/20 [00:00<00:00, 31.96it/s]


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([5, 0, 8, 8, 4, 0, 0, 2, 8, 6, 1, 1, 3, 3, 2, 2, 4, 7, 7, 9, 9, 6, 4, 7,
        4, 8, 2, 7, 0, 5, 1, 7, 6, 3, 3, 9, 8, 9, 3, 5, 4, 0, 2, 6, 4, 9, 8, 7,
        3, 8, 5, 5, 3, 8, 8, 5, 2, 9, 9, 8, 5, 7, 9, 7, 0, 0, 0, 1, 8, 0, 4, 2,
        8, 7, 0, 8, 7, 4, 4, 4, 0, 0, 2, 0, 6, 2, 1, 6, 0, 7, 9, 4, 8, 0, 9, 6,
        1, 5, 1, 4, 7, 5, 0, 3, 0, 9, 2, 1, 5, 6, 7, 1, 1, 3, 5, 9, 5, 4, 9, 8,
        0, 8, 7, 0, 8, 2, 2, 8, 4, 3, 0, 9, 8, 2, 3], device='cuda:2')
0.0


Epoch 5 Dataset 0:   0%|          | 0/20 [00:00<?, ?it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([4, 4, 0, 2, 8, 5, 6, 3, 2, 9, 6, 6, 2, 2, 6, 6, 0, 7, 4, 5, 3, 8, 0, 5,
        3, 9, 0, 0, 4, 1, 3, 7, 9, 9, 8, 5, 3, 8, 1, 1, 7, 1, 3, 1, 4, 8, 4, 4,
        0, 2, 5, 4, 0, 1, 1, 1, 2, 0, 1, 9, 7, 3, 3, 7, 1, 4, 6, 2, 9, 8, 1, 3,
        7, 3, 5, 4, 3, 8, 0, 9, 5, 0, 0, 8, 6, 3, 5, 1, 5, 9, 4, 2, 2, 1, 7, 8,
        1, 8, 6, 2, 6, 8, 6, 1, 6, 1, 8, 6, 2, 2, 1, 5, 9, 8, 9, 2, 3, 6, 3, 2,
        0, 9, 2, 4, 2, 0, 0, 0, 2, 5, 9, 7, 4, 2, 1, 1, 8, 7, 8, 0, 8, 7, 2, 5,
        4, 6, 2, 1, 4, 9, 1, 7, 6, 3, 6, 2, 4, 0, 0, 2, 7, 4, 2, 4, 1, 0, 3, 2,
        8, 5, 6, 3, 2, 0, 8, 1, 2, 0, 2, 6, 3, 9, 3, 9, 6, 6, 5, 2, 9, 2, 6, 3,
        3, 7, 9, 6, 4, 6, 0, 5, 2, 4, 7, 3, 8, 0, 6, 6, 7, 5, 7, 6, 8, 5, 9, 4,
        0

Epoch 5 Dataset 0:  15%|█▌        | 3/20 [00:00<00:00, 29.55it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([0, 5, 3, 4, 1, 3, 3, 6, 7, 9, 7, 4, 4, 6, 9, 0, 0, 3, 4, 8, 4, 6, 4, 4,
        0, 4, 3, 6, 6, 0, 4, 5, 6, 3, 2, 1, 6, 9, 4, 2, 6, 8, 7, 8, 9, 1, 7, 1,
        4, 6, 7, 5, 5, 3, 5, 5, 9, 8, 3, 6, 7, 4, 8, 3, 8, 3, 2, 1, 2, 1, 2, 3,
        1, 1, 0, 8, 4, 8, 4, 1, 5, 4, 5, 6, 0, 4, 8, 6, 1, 3, 2, 3, 5, 2, 8, 7,
        6, 2, 9, 3, 7, 2, 0, 0, 6, 3, 1, 9, 8, 4, 8, 7, 9, 8, 0, 3, 8, 7, 9, 8,
        5, 0, 1, 7, 9, 9, 7, 8, 4, 3, 8, 9, 0, 4, 2, 4, 7, 5, 1, 6, 1, 3, 7, 4,
        0, 3, 7, 5, 7, 3, 6, 9, 8, 9, 1, 1, 7, 2, 9, 4, 1, 1, 2, 1, 9, 0, 7, 2,
        7, 2, 3, 7, 4, 2, 0, 0, 0, 1, 8, 0, 9, 9, 9, 6, 6, 6, 9, 4, 5, 3, 3, 2,
        8, 8, 1, 2, 5, 4, 6, 3, 2, 6, 8, 5, 0, 1, 6, 1, 0, 9, 8, 7, 1, 7, 9, 5,
        7

Epoch 5 Dataset 0:  35%|███▌      | 7/20 [00:00<00:00, 31.47it/s]

0.0
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([2, 6, 0, 9, 6, 1, 9, 9, 6, 4, 9, 8, 3, 0, 0, 8, 9, 4, 5, 4, 4, 5, 6, 6,
        3, 5, 1, 2, 9, 0, 6, 3, 2, 2, 9, 2, 6, 7, 8, 2, 1, 0, 2, 6, 3, 2, 8, 5,
        0, 2, 0, 4, 8, 9, 7, 8, 2, 9, 4, 9, 7, 1, 2, 7, 0, 5, 6, 5, 1, 8, 8, 2,
        0, 8, 5, 0, 2, 6, 9, 3, 4, 2, 6, 3, 3, 5, 6, 7, 7, 2, 0, 8, 7, 2, 8, 8,
        2, 3, 3, 7, 7, 9, 8, 6, 5, 1, 8, 8, 4, 8, 6, 5, 5, 1, 7, 4, 5, 8, 2, 9,
        3, 4, 3, 0, 2, 0, 3, 9, 7, 7, 3, 9, 6, 5, 7, 0, 2, 0, 2, 5, 9, 3, 8, 5,
        0, 9, 2, 4, 3, 1, 9, 9, 1, 9, 9, 6, 0, 1, 3, 5, 6, 0, 1, 7, 3, 9, 0, 1,
        4, 2, 1, 8, 2, 5, 7, 0, 6, 4, 6, 6, 8, 6, 9, 8, 8, 4, 1, 9, 1, 7, 3, 0,
        7, 7, 2, 5, 2, 0, 0, 8, 5, 7, 2, 0, 7, 3, 7, 8, 8, 7, 8, 6, 0, 1, 8, 7,
     

Epoch 5 Dataset 0:  55%|█████▌    | 11/20 [00:00<00:00, 32.33it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([7, 1, 5, 0, 3, 6, 9, 8, 8, 5, 3, 6, 1, 7, 3, 2, 6, 8, 7, 9, 6, 2, 3, 4,
        0, 9, 1, 7, 0, 3, 0, 2, 7, 8, 3, 4, 4, 7, 9, 7, 5, 7, 4, 1, 5, 5, 0, 8,
        0, 7, 4, 1, 7, 5, 9, 8, 0, 1, 6, 6, 3, 7, 1, 1, 3, 9, 5, 0, 2, 5, 2, 9,
        6, 9, 7, 7, 4, 7, 3, 1, 4, 1, 2, 9, 3, 0, 7, 0, 4, 5, 4, 6, 8, 9, 7, 1,
        9, 1, 3, 4, 2, 9, 0, 7, 6, 0, 9, 9, 9, 7, 2, 8, 9, 4, 3, 4, 0, 2, 2, 2,
        8, 5, 6, 8, 7, 7, 2, 1, 0, 7, 0, 2, 5, 0, 6, 7, 5, 8, 1, 9, 8, 1, 6, 3,
        6, 6, 2, 9, 7, 5, 4, 3, 3, 6, 7, 2, 3, 7, 1, 6, 5, 9, 5, 2, 9, 1, 6, 9,
        4, 7, 2, 2, 0, 0, 8, 7, 5, 6, 9, 2, 3, 1, 5, 3, 7, 4, 1, 1, 0, 0, 8, 9,
        7, 5, 9, 2, 3, 4, 6, 1, 3, 0, 4, 5, 2, 0, 7, 4, 5, 0, 7, 8, 8, 6, 8, 8,
        5

Epoch 5 Dataset 0:  75%|███████▌  | 15/20 [00:00<00:00, 32.72it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([3, 2, 1, 8, 8, 0, 0, 1, 7, 1, 2, 4, 9, 1, 5, 1, 8, 9, 2, 4, 5, 6, 3, 1,
        3, 8, 6, 6, 1, 1, 5, 6, 7, 2, 1, 4, 4, 7, 8, 6, 0, 1, 8, 0, 6, 2, 9, 5,
        6, 1, 1, 1, 8, 0, 8, 7, 1, 5, 7, 5, 3, 3, 1, 1, 9, 8, 7, 6, 3, 0, 7, 5,
        8, 2, 8, 8, 7, 6, 9, 8, 1, 0, 4, 5, 7, 9, 3, 0, 9, 1, 6, 4, 1, 9, 8, 1,
        8, 8, 9, 2, 7, 6, 3, 3, 3, 4, 3, 1, 2, 1, 7, 6, 7, 2, 5, 9, 9, 6, 9, 2,
        5, 3, 2, 9, 5, 2, 1, 7, 3, 9, 6, 1, 3, 5, 2, 4, 0, 2, 3, 9, 1, 3, 9, 9,
        8, 0, 2, 0, 4, 0, 0, 8, 1, 6, 8, 6, 9, 6, 6, 0, 1, 1, 7, 9, 4, 5, 9, 4,
        1, 7, 5, 2, 8, 1, 9, 6, 2, 4, 4, 6, 0, 4, 2, 7, 6, 7, 5, 5, 9, 6, 9, 8,
        2, 9, 1, 8, 8, 1, 0, 1, 8, 2, 8, 3, 6, 8, 8, 9, 4, 6, 7, 8, 8, 0, 8, 0,
        9

Epoch 5 Dataset 0:  85%|████████▌ | 17/20 [00:00<00:00, 30.51it/s]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:2') tensor([6, 5, 3, 8, 4, 7, 5, 6, 4, 9, 2, 4, 1, 3, 4, 9, 7, 5, 5, 9, 7, 0, 4, 7,
        1, 3, 5, 3, 3, 0, 9, 4, 2, 0, 5, 4, 8, 8, 7, 5, 3, 7, 2, 6, 6, 8, 3, 3,
        6, 0, 1, 3, 1, 4, 2, 4, 5, 7, 8, 9, 7, 4, 3, 7, 9, 5, 8, 1, 3, 2, 0, 2,
        7, 7, 8, 4, 9, 9, 8, 0, 2, 4, 4, 5, 5, 3, 0, 3, 2, 8, 3, 6, 7, 7, 9, 7,
        4, 2, 4, 8, 1, 5, 4, 6, 9, 4, 7, 2, 1, 1, 9, 7, 2, 5, 7, 1, 5, 0, 9, 6,
        6, 6, 0, 8, 0, 2, 5, 6, 0, 1, 5, 9, 8, 1, 6, 3, 0, 5, 3, 3, 8, 8, 1, 8,
        0, 9, 3, 8, 6, 1, 6, 0, 2, 2, 4, 9, 1, 2, 3, 5, 2, 6, 2, 6, 1, 9, 5, 6,
        2, 7, 7, 4, 0, 3, 8, 3, 8, 1, 7, 0, 2, 0, 8, 9, 8, 3, 5, 3, 2, 2, 7, 5,
        3, 5, 0, 0, 9, 0, 1, 7, 5, 8, 0, 7, 3, 4, 1, 9, 4, 7, 7, 1, 8, 6, 9, 7,
        8




KeyboardInterrupt: 

In [None]:
print(pbo_test_losses)
torch.save(pbo_test_losses, 'mnist_pbo_test_losses.pth')

In [None]:
from scipy.stats import spearmanr, pearsonr

In [None]:
# calculate correlation between influence scores and PBO test losses
print(f"Spearman's rho: {spearmanr(influence_scores, pbo_test_losses)}")
print(f"Pearson's r: {pearsonr(influence_scores, pbo_test_losses)}")