In [1]:
import torch.nn as nn
from tqdm import tqdm
from e3nn.nn import _fc
import torch
from torch_geometric.data.data import Data
from models.geometricv2 import GeometricNetV2
from dataset import Dataset, create_transform
from train import load_indexes
from torch_geometric.loader import DataLoader
from IPython.display import clear_output

# --- Data location and input files -----------------------------------------
# DATA_DIR is handed to load_indexes, create_transform and Dataset; the three
# file names below are passed to create_transform together with DATA_DIR.
DATA_DIR = 'data/ala_dipep'
LABELS_FILE = 'phi-psi-free-energy.txt'
BONDS_FILE = 'ala_dipep_bonds.csv'
PARTIAL_CHARGES_FILE = 'ala_dipep_partial_charges.csv'

# --- Split / sampling -------------------------------------------------------
TEST_ON = 'all'        # second positional argument of load_indexes
N_SAMPLES = 2000       # forwarded to load_indexes as n_samples
BATCH_SIZE = 10        # batch size of the DataLoader

# --- Model / optimisation ---------------------------------------------------
ENERGY_LEVELS = 40     # forwarded to create_transform as energy_levels
LR = 1e-3              # learning rate passed to GeometricNetV2

# --- Optional warm start (at most one of the two may be set) ----------------
# Example checkpoint: 'tmpdir-all-1124-1745/checkpoint/epoch=199-step=41799.ckpt'
CHECKPOINT = None
WEIGHTS = None

In [2]:
# Resolve the train / validation / test index split for the configured data set.
train_indexes, validation_indexes, test_indexes = load_indexes(
    DATA_DIR,
    TEST_ON,
    n_samples=N_SAMPLES,
)

# Per-sample transform built from the label, bond and partial-charge files.
transform = create_transform(
    DATA_DIR,
    LABELS_FILE,
    BONDS_FILE,
    PARTIAL_CHARGES_FILE,
    use_dihedrals=False,
    energy_levels=ENERGY_LEVELS,
)

# Only the test split is materialised in this notebook; batches are served in
# a fixed order (shuffle=False).
test_dataset = Dataset(
    data_dir=DATA_DIR,
    indexes=test_indexes,
    transform=transform,
    use_dihedrals=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

loading indexes
indexes loaded


In [3]:
# Build `model` from exactly one source:
#   * CHECKPOINT set -> restore via GeometricNetV2.load_from_checkpoint
#   * WEIGHTS set    -> fresh model, then load a state dict into it
#   * neither        -> freshly initialised model
# The model constructor needs one example batch, so fetch it once up front.
sample_batch = next(iter(test_loader))

if CHECKPOINT:
    assert not WEIGHTS, 'Params --weights and --checkpoint are mutually exclusive'
    print(f'Loading model from checkpoint: {CHECKPOINT}')
    model = GeometricNetV2.load_from_checkpoint(
        strict=False,
        checkpoint_path=CHECKPOINT,
        sample=sample_batch,
        lr=LR,
    )
else:
    model = GeometricNetV2(sample_batch, LR)
    if WEIGHTS:
        print(f'Loading model weights  {WEIGHTS}')
        try:
            # The model still lives on the CPU at this point, so map the stored
            # tensors to the CPU as well. Without map_location, weights saved
            # from a GPU are restored onto that GPU first, which fails on a
            # machine without CUDA.
            # NOTE: torch.load unpickles the file; only load weight files from
            # a trusted source.
            state_dict = torch.load(WEIGHTS, map_location='cpu')
            model.load_state_dict(state_dict, strict=True)
            print(f'Model weights {WEIGHTS} loaded')
        except Exception as e:
            # Best effort: on failure the notebook continues with the freshly
            # initialised (untrained) model created above.
            print(f'Model weights could not be loaded: {str(e)}')
    else:
        print('Initializing new model')

Initializing new model


In [4]:
def show_weights(model):
    """Print the largest weight-gradient entry of the model's trailing layers.

    Despite the name, this reports *gradients*, not weights. Modules are
    visited in reverse registration order:

    * every ``nn.Linear`` prints the maximum over the first 10 entries of the
      first row of its weight gradient;
    * the first ``_fc._Layer`` encountered prints the maximum over its whole
      weight gradient and ends the walk.

    A layer whose weight has no gradient yet (``.grad is None``, e.g. before
    the first backward pass or for a frozen/unused layer) prints ``None``
    instead of raising.

    Args:
        model: any ``torch.nn.Module``.
    """
    for m in reversed(list(model.modules())):
        if isinstance(m, nn.Linear):
            grad = m.weight.grad
            print('Linear', None if grad is None else torch.max(grad[0][:10]))
        elif isinstance(m, _fc._Layer):
            grad = m.weight.grad
            print('GraphLinear', None if grad is None else torch.max(grad))
            # Only the last graph layer is of interest; stop here.
            break

In [5]:
# Manual optimisation loop over `test_loader`, used to watch the loss and the
# gradient magnitudes of the last layers while the model takes real steps.

# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model to its device BEFORE constructing the optimizer, as the
# torch.optim documentation recommends.
model = model.to(device)
model.train()

optimizers, _ = model.configure_optimizers()
optimizer = optimizers[0]

log_every = 25  # number of batches between progress reports

# Wrapping the loader (not the enumerate object) lets tqdm see its length and
# show a total / ETA.
for idx, data in enumerate(tqdm(test_loader)):
    # Move the batch to the model's device. For a list, rebuild it so the
    # moved objects are the ones actually used below.
    if isinstance(data, list):
        data = [d.to(device) for d in data]
    elif isinstance(data, Data):
        data = data.to(device)

    # Run one training step; the step returns a dict holding the batch loss.
    out = model.training_step(data, idx)
    loss = out['loss']

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if idx % log_every == 0:
        clear_output(wait=False)
        print(loss.item())
        show_weights(model)

126it [02:45,  1.10s/it]

8.006976127624512
Linear tensor(0.0073, device='cuda:0')
Linear tensor(0.0049, device='cuda:0')
Linear tensor(8.3500e-06, device='cuda:0')
GraphLinear tensor(9.8832e-05, device='cuda:0')
tensor([0.0428, 0.1266, 0.2303, 0.1650, 0.1317, 0.0642, 0.0488, 0.0424, 0.0275,
        0.0129, 0.0074, 0.0047, 0.0040, 0.0041, 0.0041, 0.0036, 0.0036, 0.0036,
        0.0032, 0.0035, 0.0033, 0.0033, 0.0033, 0.0033, 0.0035, 0.0033, 0.0032,
        0.0032, 0.0033, 0.0032, 0.0034, 0.0034, 0.0031, 0.0034, 0.0035, 0.0034,
        0.0031, 0.0033, 0.0034, 0.0033], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0340, 0.9300, 0.0360, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


127it [02:46,  1.10s/it]

tensor([0.0425, 0.1211, 0.1958, 0.1688, 0.1262, 0.0644, 0.0532, 0.0475, 0.0308,
        0.0144, 0.0087, 0.0059, 0.0051, 0.0050, 0.0052, 0.0045, 0.0045, 0.0044,
        0.0040, 0.0043, 0.0042, 0.0042, 0.0042, 0.0041, 0.0044, 0.0041, 0.0041,
        0.0041, 0.0042, 0.0040, 0.0043, 0.0042, 0.0040, 0.0042, 0.0044, 0.0043,
        0.0039, 0.0042, 0.0042, 0.0042], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.3420, 0.6560, 0.0020, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


128it [02:48,  1.10s/it]

tensor([0.0442, 0.1219, 0.1763, 0.1577, 0.1211, 0.0665, 0.0577, 0.0514, 0.0318,
        0.0152, 0.0095, 0.0067, 0.0058, 0.0059, 0.0063, 0.0054, 0.0051, 0.0051,
        0.0046, 0.0049, 0.0049, 0.0048, 0.0048, 0.0048, 0.0051, 0.0048, 0.0048,
        0.0047, 0.0049, 0.0046, 0.0049, 0.0048, 0.0046, 0.0049, 0.0050, 0.0050,
        0.0046, 0.0049, 0.0049, 0.0049], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0080, 0.8620, 0.1290, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


129it [02:49,  1.10s/it]

tensor([0.0460, 0.1285, 0.1741, 0.1563, 0.1208, 0.0666, 0.0577, 0.0519, 0.0314,
        0.0149, 0.0092, 0.0065, 0.0056, 0.0058, 0.0062, 0.0053, 0.0049, 0.0049,
        0.0045, 0.0047, 0.0048, 0.0047, 0.0047, 0.0046, 0.0050, 0.0047, 0.0047,
        0.0046, 0.0047, 0.0045, 0.0048, 0.0047, 0.0045, 0.0047, 0.0049, 0.0049,
        0.0045, 0.0047, 0.0047, 0.0047], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0010, 0.4520, 0.5460, 0.0010, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


130it [02:50,  1.10s/it]

tensor([0.0491, 0.1446, 0.1808, 0.1574, 0.1237, 0.0672, 0.0561, 0.0502, 0.0287,
        0.0134, 0.0080, 0.0056, 0.0048, 0.0050, 0.0052, 0.0045, 0.0041, 0.0042,
        0.0038, 0.0041, 0.0040, 0.0040, 0.0039, 0.0038, 0.0041, 0.0041, 0.0039,
        0.0039, 0.0040, 0.0039, 0.0040, 0.0040, 0.0039, 0.0040, 0.0042, 0.0041,
        0.0038, 0.0039, 0.0040, 0.0040], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0920,
        0.8960, 0.0120, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


131it [02:51,  1.10s/it]

tensor([0.0519, 0.1448, 0.1586, 0.1379, 0.1206, 0.0719, 0.0624, 0.0543, 0.0307,
        0.0147, 0.0089, 0.0065, 0.0055, 0.0057, 0.0061, 0.0054, 0.0048, 0.0050,
        0.0046, 0.0047, 0.0048, 0.0047, 0.0047, 0.0046, 0.0049, 0.0049, 0.0047,
        0.0047, 0.0048, 0.0047, 0.0047, 0.0048, 0.0047, 0.0047, 0.0049, 0.0049,
        0.0047, 0.0047, 0.0048, 0.0047], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0720, 0.9120, 0.0160, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


132it [02:52,  1.10s/it]

tensor([0.0493, 0.1547, 0.1803, 0.1490, 0.1291, 0.0744, 0.0614, 0.0502, 0.0273,
        0.0124, 0.0071, 0.0049, 0.0042, 0.0043, 0.0044, 0.0039, 0.0035, 0.0037,
        0.0033, 0.0035, 0.0035, 0.0035, 0.0034, 0.0033, 0.0035, 0.0037, 0.0034,
        0.0034, 0.0034, 0.0034, 0.0035, 0.0035, 0.0034, 0.0034, 0.0036, 0.0036,
        0.0034, 0.0034, 0.0035, 0.0034], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.8860, 0.1140, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


133it [02:53,  1.10s/it]

tensor([0.0466, 0.1753, 0.2024, 0.1614, 0.1363, 0.0755, 0.0551, 0.0431, 0.0217,
        0.0094, 0.0051, 0.0034, 0.0028, 0.0029, 0.0027, 0.0025, 0.0022, 0.0024,
        0.0021, 0.0024, 0.0022, 0.0023, 0.0022, 0.0021, 0.0022, 0.0025, 0.0022,
        0.0022, 0.0022, 0.0022, 0.0023, 0.0024, 0.0022, 0.0022, 0.0024, 0.0023,
        0.0022, 0.0021, 0.0023, 0.0022], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0170, 0.9140, 0.0690,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


134it [02:54,  1.10s/it]

tensor([0.0487, 0.1906, 0.2080, 0.1580, 0.1336, 0.0725, 0.0532, 0.0415, 0.0201,
        0.0085, 0.0045, 0.0030, 0.0025, 0.0026, 0.0024, 0.0022, 0.0020, 0.0022,
        0.0019, 0.0021, 0.0019, 0.0020, 0.0019, 0.0018, 0.0019, 0.0022, 0.0019,
        0.0020, 0.0020, 0.0020, 0.0021, 0.0021, 0.0019, 0.0020, 0.0021, 0.0020,
        0.0020, 0.0019, 0.0020, 0.0020], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.2680, 0.7290, 0.0030, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


135it [02:55,  1.10s/it]

tensor([0.0546, 0.2264, 0.2078, 0.1421, 0.1312, 0.0689, 0.0491, 0.0371, 0.0179,
        0.0077, 0.0041, 0.0027, 0.0021, 0.0023, 0.0021, 0.0020, 0.0017, 0.0020,
        0.0017, 0.0019, 0.0017, 0.0018, 0.0017, 0.0016, 0.0017, 0.0020, 0.0017,
        0.0017, 0.0017, 0.0017, 0.0018, 0.0018, 0.0017, 0.0018, 0.0018, 0.0017,
        0.0017, 0.0016, 0.0018, 0.0017], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.1800, 0.8150, 0.0050, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


136it [02:56,  1.10s/it]

tensor([0.0556, 0.2620, 0.2157, 0.1398, 0.1209, 0.0607, 0.0433, 0.0318, 0.0156,
        0.0067, 0.0035, 0.0023, 0.0017, 0.0020, 0.0017, 0.0017, 0.0014, 0.0017,
        0.0014, 0.0016, 0.0014, 0.0015, 0.0014, 0.0013, 0.0014, 0.0016, 0.0014,
        0.0014, 0.0014, 0.0014, 0.0015, 0.0015, 0.0014, 0.0015, 0.0015, 0.0014,
        0.0014, 0.0013, 0.0015, 0.0015], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0160, 0.9130, 0.0710,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


137it [02:57,  1.10s/it]

tensor([0.0528, 0.3021, 0.2459, 0.1282, 0.1011, 0.0494, 0.0378, 0.0280, 0.0132,
        0.0054, 0.0028, 0.0017, 0.0013, 0.0015, 0.0013, 0.0012, 0.0011, 0.0013,
        0.0010, 0.0012, 0.0010, 0.0011, 0.0010, 0.0009, 0.0010, 0.0012, 0.0010,
        0.0011, 0.0011, 0.0011, 0.0011, 0.0012, 0.0010, 0.0011, 0.0011, 0.0011,
        0.0011, 0.0009, 0.0012, 0.0011], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0020, 0.6750, 0.3230, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


138it [02:59,  1.10s/it]

tensor([0.0587, 0.2773, 0.2235, 0.1297, 0.0983, 0.0522, 0.0422, 0.0352, 0.0177,
        0.0076, 0.0040, 0.0026, 0.0020, 0.0023, 0.0022, 0.0020, 0.0018, 0.0021,
        0.0017, 0.0019, 0.0017, 0.0018, 0.0017, 0.0016, 0.0017, 0.0019, 0.0017,
        0.0017, 0.0018, 0.0018, 0.0018, 0.0018, 0.0017, 0.0018, 0.0018, 0.0017,
        0.0018, 0.0016, 0.0019, 0.0018], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.2900, 0.7070, 0.0020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


139it [03:00,  1.10s/it]

tensor([0.0626, 0.1997, 0.1760, 0.1325, 0.0998, 0.0619, 0.0552, 0.0484, 0.0262,
        0.0122, 0.0075, 0.0054, 0.0046, 0.0048, 0.0049, 0.0044, 0.0040, 0.0043,
        0.0038, 0.0039, 0.0039, 0.0040, 0.0038, 0.0037, 0.0039, 0.0040, 0.0038,
        0.0038, 0.0040, 0.0039, 0.0039, 0.0039, 0.0039, 0.0040, 0.0040, 0.0040,
        0.0039, 0.0037, 0.0040, 0.0039], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.7400, 0.2600, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


140it [03:01,  1.10s/it]

tensor([0.0631, 0.1901, 0.1747, 0.1385, 0.1001, 0.0626, 0.0555, 0.0486, 0.0260,
        0.0121, 0.0076, 0.0055, 0.0049, 0.0051, 0.0051, 0.0045, 0.0041, 0.0043,
        0.0038, 0.0040, 0.0040, 0.0040, 0.0039, 0.0038, 0.0040, 0.0041, 0.0039,
        0.0039, 0.0041, 0.0039, 0.0040, 0.0040, 0.0039, 0.0041, 0.0041, 0.0041,
        0.0040, 0.0038, 0.0041, 0.0040], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0110, 0.8900, 0.0990, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


141it [03:02,  1.10s/it]

tensor([0.0594, 0.1817, 0.1802, 0.1459, 0.0990, 0.0625, 0.0537, 0.0490, 0.0268,
        0.0124, 0.0077, 0.0057, 0.0052, 0.0051, 0.0051, 0.0044, 0.0041, 0.0043,
        0.0038, 0.0040, 0.0040, 0.0041, 0.0039, 0.0038, 0.0041, 0.0041, 0.0039,
        0.0039, 0.0041, 0.0039, 0.0040, 0.0040, 0.0039, 0.0040, 0.0041, 0.0041,
        0.0039, 0.0038, 0.0041, 0.0040], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2750,
        0.7220, 0.0030, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


142it [03:03,  1.10s/it]

tensor([0.0559, 0.1727, 0.1714, 0.1512, 0.1030, 0.0667, 0.0538, 0.0488, 0.0274,
        0.0129, 0.0082, 0.0062, 0.0058, 0.0054, 0.0053, 0.0046, 0.0042, 0.0044,
        0.0041, 0.0042, 0.0042, 0.0042, 0.0041, 0.0040, 0.0043, 0.0043, 0.0042,
        0.0041, 0.0042, 0.0041, 0.0042, 0.0042, 0.0041, 0.0042, 0.0043, 0.0043,
        0.0041, 0.0041, 0.0042, 0.0042], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0470, 0.9270, 0.0250, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


143it [03:04,  1.10s/it]

tensor([0.0551, 0.1836, 0.1857, 0.1501, 0.1031, 0.0656, 0.0512, 0.0458, 0.0256,
        0.0119, 0.0077, 0.0060, 0.0056, 0.0049, 0.0046, 0.0040, 0.0039, 0.0041,
        0.0037, 0.0038, 0.0037, 0.0038, 0.0036, 0.0036, 0.0037, 0.0038, 0.0036,
        0.0036, 0.0037, 0.0037, 0.0037, 0.0038, 0.0036, 0.0037, 0.0038, 0.0037,
        0.0037, 0.0036, 0.0038, 0.0037], device='cuda:0',
       grad_fn=<SelectBackward>) tensor([0.0000, 0.0000, 0.0000, 0.3800, 0.6190, 0.0010, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000], device='cuda:0')


144it [03:05,  1.29s/it]
