In [1]:
import torch
import yaml
import ast
#from dig.threedgraph.dataset import QM93D
from dig.threedgraph.dataset import MD17
from dig.threedgraph.dataset.PygTobermorite import Tobermorite
from dig.threedgraph.method import SphereNet #SchNet, DimeNetPP, ComENet
from dig.threedgraph.method import run
from dig.threedgraph.evaluation import ThreeDEvaluator

# Load configs

In [2]:
with open('config.yaml', 'r') as c:
    config = yaml.safe_load(c)
    
# For strings that yaml doesn't parse (e.g. None)
for key, val in config.items():
    if type(val) is str:
        try:
            config[key] = ast.literal_eval(val)
        except (ValueError, SyntaxError):
            pass

In [3]:
name = config['name']
n_train = config['n_train']
n_val = config['n_val']
seed = config['seed']
energy_and_force = config['energy_and_force']
cutoff = config['cutoff']
num_layers = config['num_layers']
hidden_channels = config['hidden_channels']
out_channels = config['out_channels']
int_emb_size = config['int_emb_size']
basis_emb_size_dist = config['basis_emb_size_dist']
basis_emb_size_angle = config['basis_emb_size_angle']
basis_emb_size_torsion = config['basis_emb_size_torsion']
out_emb_channels = config['out_emb_channels']
num_spherical = config['num_spherical']
num_radial = config['num_radial']
envelope_exponent = config['envelope_exponent']
num_before_skip = config['num_before_skip']
num_after_skip = config['num_after_skip']
num_output_layers = config['num_output_layers']

epochs = config['epochs']
batch_size = config['batch_size']
vt_batch_size = config['vt_batch_size']
lr = config['lr']
lr_decay_factor = config['lr_decay_factor']
lr_decay_step_size = config['lr_decay_step_size']

In [4]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

# Dataset

In [5]:
dataset = Tobermorite(root='dataset/', name=name)

split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=n_train, valid_size=n_val, seed=seed)

train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
print('train, validaion, test:', len(train_dataset), len(valid_dataset), len(test_dataset))

train, validaion, test: 950 50 6000




# Model

In [6]:
model = SphereNet(energy_and_force=energy_and_force, cutoff=cutoff, num_layers=num_layers, 
        hidden_channels=hidden_channels, out_channels=out_channels, int_emb_size=int_emb_size, 
        basis_emb_size_dist=basis_emb_size_dist, basis_emb_size_angle=basis_emb_size_angle, 
        basis_emb_size_torsion=basis_emb_size_torsion, out_emb_channels=out_emb_channels, 
        num_spherical=num_spherical, num_radial=num_radial, envelope_exponent=envelope_exponent, 
        num_before_skip=num_before_skip, num_after_skip=num_after_skip, num_output_layers=num_output_layers 
        )
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

# FLOPs

In [7]:
from torch_geometric.data import DataLoader
from tqdm import tqdm
from thop import profile

train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
for step, batch_data in enumerate(tqdm(train_loader)):
    print(device)
    batch_data = batch_data.to(device)
    inputs = batch_data
    model = model.to(device)
    flops,params=profile(model,(inputs, ))
    print(flops,params)

  0%|                                                                                                                                             | 0/238 [00:00<?, ?it/s]

cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


  0%|▌                                                                                                                                    | 1/238 [00:04<17:18,  4.38s/it]

2602690688.0 1877952.0
cuda:0


  1%|█                                                                                                                                    | 2/238 [00:04<07:53,  2.01s/it]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2495891712.0 1877952.0
cuda:0


  1%|█▋                                                                                                                                   | 3/238 [00:05<04:50,  1.24s/it]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2615763200.0 1877952.0
cuda:0


  2%|██▏                                                                                                                                  | 4/238 [00:05<03:27,  1.13it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2495546112.0 1877952.0
cuda:0


  2%|██▊                                                                                                                                  | 5/238 [00:05<02:41,  1.45it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2564401536.0 1877952.0
cuda:0


  3%|███▎                                                                                                                                 | 6/238 [00:06<02:12,  1.74it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2490491648.0 1877952.0
cuda:0


  3%|███▉                                                                                                                                 | 7/238 [00:06<01:52,  2.04it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2588920960.0 1877952.0
cuda:0


  3%|████▍                                                                                                                                | 8/238 [00:06<01:39,  2.31it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2573774848.0 1877952.0
cuda:0


  4%|█████                                                                                                                                | 9/238 [00:07<01:32,  2.48it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2564356352.0 1877952.0
cuda:0


  4%|█████▌                                                                                                                              | 10/238 [00:07<01:26,  2.63it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2585561344.0 1877952.0
cuda:0


  5%|██████                                                                                                                              | 11/238 [00:07<01:23,  2.71it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2584537344.0 1877952.0
cuda:0


  5%|██████▋                                                                                                                             | 12/238 [00:08<01:22,  2.74it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2564612352.0 1877952.0
cuda:0


  5%|███████▏                                                                                                                            | 13/238 [00:08<01:20,  2.81it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2620235648.0 1877952.0
cuda:0


  6%|███████▊                                                                                                                            | 14/238 [00:08<01:18,  2.86it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2526738816.0 1877952.0
cuda:0


  6%|████████▎                                                                                                                           | 15/238 [00:09<01:16,  2.92it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2574517248.0 1877952.0
cuda:0


  7%|████████▊                                                                                                                           | 16/238 [00:09<01:15,  2.93it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2586585344.0 1877952.0
cuda:0


  7%|█████████▍                                                                                                                          | 17/238 [00:09<01:13,  2.99it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2623325696.0 1877952.0
cuda:0


  8%|█████████▉                                                                                                                          | 18/238 [00:10<01:13,  2.99it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2571138816.0 1877952.0
cuda:0


  8%|██████████▌                                                                                                                         | 19/238 [00:10<01:13,  2.98it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2562245120.0 1877952.0
cuda:0


  8%|███████████                                                                                                                         | 20/238 [00:10<01:14,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2629500544.0 1877952.0
cuda:0


  9%|███████████▋                                                                                                                        | 21/238 [00:11<01:13,  2.97it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2597878656.0 1877952.0
cuda:0


  9%|████████████▏                                                                                                                       | 22/238 [00:11<01:11,  3.04it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2583776128.0 1877952.0
cuda:0


 10%|████████████▊                                                                                                                       | 23/238 [00:11<01:10,  3.04it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2578926464.0 1877952.0
cuda:0


 10%|█████████████▎                                                                                                                      | 24/238 [00:12<01:11,  3.01it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2546350592.0 1877952.0
cuda:0


 11%|█████████████▊                                                                                                                      | 25/238 [00:12<01:10,  3.03it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2457147776.0 1877952.0
cuda:0


 11%|██████████████▍                                                                                                                     | 26/238 [00:12<01:11,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2623716480.0 1877952.0
cuda:0


 11%|██████████████▉                                                                                                                     | 27/238 [00:13<01:11,  2.97it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2631861760.0 1877952.0
cuda:0


 12%|███████████████▌                                                                                                                    | 28/238 [00:13<01:10,  2.99it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2545377792.0 1877952.0
cuda:0


 12%|████████████████                                                                                                                    | 29/238 [00:13<01:10,  2.98it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2503614592.0 1877952.0
cuda:0


 13%|████████████████▋                                                                                                                   | 30/238 [00:14<01:11,  2.91it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2611424768.0 1877952.0
cuda:0


 13%|█████████████████▏                                                                                                                  | 31/238 [00:14<01:11,  2.91it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2592779008.0 1877952.0
cuda:0


 13%|█████████████████▋                                                                                                                  | 32/238 [00:14<01:11,  2.90it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2547080192.0 1877952.0
cuda:0


 14%|██████████████████▎                                                                                                                 | 33/238 [00:15<01:09,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2575841664.0 1877952.0
cuda:0


 14%|██████████████████▊                                                                                                                 | 34/238 [00:15<01:09,  2.96it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2610323968.0 1877952.0
cuda:0


 15%|███████████████████▍                                                                                                                | 35/238 [00:15<01:08,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2564433152.0 1877952.0
cuda:0


 15%|███████████████████▉                                                                                                                | 36/238 [00:16<01:10,  2.85it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2531684864.0 1877952.0
cuda:0


 16%|████████████████████▌                                                                                                               | 37/238 [00:16<01:08,  2.91it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2649918720.0 1877952.0
cuda:0


 16%|█████████████████████                                                                                                               | 38/238 [00:16<01:08,  2.93it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2590872576.0 1877952.0
cuda:0


 16%|█████████████████████▋                                                                                                              | 39/238 [00:17<01:07,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2596490240.0 1877952.0
cuda:0


 17%|██████████████████████▏                                                                                                             | 40/238 [00:17<01:07,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2597219840.0 1877952.0
cuda:0


 17%|██████████████████████▋                                                                                                             | 41/238 [00:17<01:06,  2.96it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2518382720.0 1877952.0
cuda:0


 18%|███████████████████████▎                                                                                                            | 42/238 [00:18<01:06,  2.93it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2524281984.0 1877952.0
cuda:0


 18%|███████████████████████▊                                                                                                            | 43/238 [00:18<01:05,  2.96it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2522714368.0 1877952.0
cuda:0


 18%|████████████████████████▍                                                                                                           | 44/238 [00:18<01:04,  3.02it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2585632128.0 1877952.0
cuda:0


 19%|████████████████████████▉                                                                                                           | 45/238 [00:19<01:03,  3.03it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2635483392.0 1877952.0
cuda:0


 19%|█████████████████████████▌                                                                                                          | 46/238 [00:19<01:05,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2612077568.0 1877952.0
cuda:0


 20%|██████████████████████████                                                                                                          | 47/238 [00:19<01:04,  2.96it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2541455744.0 1877952.0
cuda:0


 20%|██████████████████████████▌                                                                                                         | 48/238 [00:20<01:04,  2.97it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2543286144.0 1877952.0
cuda:0


 21%|███████████████████████████▏                                                                                                        | 49/238 [00:20<01:04,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2604130304.0 1877952.0
cuda:0


 21%|███████████████████████████▋                                                                                                        | 50/238 [00:20<01:03,  2.98it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2564670336.0 1877952.0
cuda:0


 21%|████████████████████████████▎                                                                                                       | 51/238 [00:21<01:02,  2.99it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2640366208.0 1877952.0
cuda:0


 22%|████████████████████████████▊                                                                                                       | 52/238 [00:21<01:03,  2.93it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2630921344.0 1877952.0
cuda:0


 22%|█████████████████████████████▍                                                                                                      | 53/238 [00:21<01:02,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2585266944.0 1877952.0
cuda:0


 23%|█████████████████████████████▉                                                                                                      | 54/238 [00:22<01:02,  2.92it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2600495872.0 1877952.0
cuda:0


 23%|██████████████████████████████▌                                                                                                     | 55/238 [00:22<01:01,  2.96it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2503524992.0 1877952.0
cuda:0


 24%|███████████████████████████████                                                                                                     | 56/238 [00:22<01:00,  3.02it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2564619136.0 1877952.0
cuda:0


 24%|███████████████████████████████▌                                                                                                    | 57/238 [00:23<00:59,  3.04it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2592094592.0 1877952.0
cuda:0


 24%|████████████████████████████████▏                                                                                                   | 58/238 [00:23<01:00,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2552922240.0 1877952.0
cuda:0


 25%|████████████████████████████████▋                                                                                                   | 59/238 [00:23<01:01,  2.93it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2510377472.0 1877952.0
cuda:0


 25%|█████████████████████████████████▎                                                                                                  | 60/238 [00:24<01:00,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2591620992.0 1877952.0
cuda:0


 26%|█████████████████████████████████▊                                                                                                  | 61/238 [00:24<00:59,  2.98it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2595601792.0 1877952.0
cuda:0


 26%|██████████████████████████████████▍                                                                                                 | 62/238 [00:24<00:58,  3.00it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2565380352.0 1877952.0
cuda:0


 26%|██████████████████████████████████▉                                                                                                 | 63/238 [00:25<00:59,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2592862592.0 1877952.0
cuda:0


 27%|███████████████████████████████████▍                                                                                                | 64/238 [00:25<01:00,  2.88it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2439827200.0 1877952.0
cuda:0


 27%|████████████████████████████████████                                                                                                | 65/238 [00:25<00:58,  2.98it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2609645568.0 1877952.0
cuda:0


 28%|████████████████████████████████████▌                                                                                               | 66/238 [00:26<00:57,  2.97it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2613299584.0 1877952.0
cuda:0


 28%|█████████████████████████████████████▏                                                                                              | 67/238 [00:26<00:59,  2.88it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2552768640.0 1877952.0
cuda:0


 29%|█████████████████████████████████████▋                                                                                              | 68/238 [00:27<00:59,  2.87it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2555903872.0 1877952.0
cuda:0


 29%|██████████████████████████████████████▎                                                                                             | 69/238 [00:27<00:57,  2.93it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2551821440.0 1877952.0
cuda:0


 29%|██████████████████████████████████████▊                                                                                             | 70/238 [00:27<00:57,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2549211008.0 1877952.0
cuda:0


 30%|███████████████████████████████████████▍                                                                                            | 71/238 [00:28<00:57,  2.90it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2604264320.0 1877952.0
cuda:0


 30%|███████████████████████████████████████▉                                                                                            | 72/238 [00:28<00:56,  2.92it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2650385536.0 1877952.0
cuda:0


 31%|████████████████████████████████████████▍                                                                                           | 73/238 [00:28<00:56,  2.92it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2653515520.0 1877952.0
cuda:0


 31%|█████████████████████████████████████████                                                                                           | 74/238 [00:29<00:55,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2641818624.0 1877952.0
cuda:0


 32%|█████████████████████████████████████████▌                                                                                          | 75/238 [00:29<00:55,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2747651584.0 1877952.0
cuda:0


 32%|██████████████████████████████████████████▏                                                                                         | 76/238 [00:29<00:55,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2568489984.0 1877952.0
cuda:0


 32%|██████████████████████████████████████████▋                                                                                         | 77/238 [00:30<00:55,  2.91it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2558392320.0 1877952.0
cuda:0


 33%|███████████████████████████████████████████▎                                                                                        | 78/238 [00:30<00:55,  2.89it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2641120640.0 1877952.0
cuda:0


 33%|███████████████████████████████████████████▊                                                                                        | 79/238 [00:30<00:53,  2.97it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2608692352.0 1877952.0
cuda:0


 34%|████████████████████████████████████████████▎                                                                                       | 80/238 [00:31<00:49,  3.17it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2589893760.0 1877952.0
cuda:0


 34%|████████████████████████████████████████████▉                                                                                       | 81/238 [00:31<00:49,  3.18it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2588536960.0 1877952.0
cuda:0


 34%|█████████████████████████████████████████████▍                                                                                      | 82/238 [00:31<00:49,  3.17it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2486312832.0 1877952.0
cuda:0


 35%|██████████████████████████████████████████████                                                                                      | 83/238 [00:31<00:49,  3.15it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2589611392.0 1877952.0
cuda:0


 35%|██████████████████████████████████████████████▌                                                                                     | 84/238 [00:32<00:49,  3.10it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2539306112.0 1877952.0
cuda:0


 36%|███████████████████████████████████████████████▏                                                                                    | 85/238 [00:32<00:50,  3.05it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2533367680.0 1877952.0
cuda:0


 36%|███████████████████████████████████████████████▋                                                                                    | 86/238 [00:32<00:49,  3.05it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2607963520.0 1877952.0
cuda:0


 37%|████████████████████████████████████████████████▎                                                                                   | 87/238 [00:33<00:48,  3.08it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2564727552.0 1877952.0
cuda:0


 37%|████████████████████████████████████████████████▊                                                                                   | 88/238 [00:33<00:47,  3.13it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2668206080.0 1877952.0
cuda:0


 37%|█████████████████████████████████████████████████▎                                                                                  | 89/238 [00:33<00:45,  3.30it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2409235328.0 1877952.0
cuda:0


 38%|█████████████████████████████████████████████████▉                                                                                  | 90/238 [00:34<00:45,  3.24it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2570972416.0 1877952.0
cuda:0


 38%|██████████████████████████████████████████████████▍                                                                                 | 91/238 [00:34<00:46,  3.19it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2623376896.0 1877952.0
cuda:0


 39%|███████████████████████████████████████████████████                                                                                 | 92/238 [00:34<00:46,  3.16it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2579130496.0 1877952.0
cuda:0


 39%|███████████████████████████████████████████████████▌                                                                                | 93/238 [00:35<00:47,  3.04it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2583283712.0 1877952.0
cuda:0


 39%|████████████████████████████████████████████████████▏                                                                               | 94/238 [00:35<00:46,  3.08it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2535082880.0 1877952.0
cuda:0


 40%|████████████████████████████████████████████████████▋                                                                               | 95/238 [00:35<00:45,  3.14it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2537597696.0 1877952.0
cuda:0


 40%|█████████████████████████████████████████████████████▏                                                                              | 96/238 [00:36<00:45,  3.11it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2699309952.0 1877952.0
cuda:0


 41%|█████████████████████████████████████████████████████▊                                                                              | 97/238 [00:36<00:44,  3.18it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2588524160.0 1877952.0
cuda:0


 41%|██████████████████████████████████████████████████████▎                                                                             | 98/238 [00:36<00:42,  3.27it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2552423040.0 1877952.0
cuda:0


 42%|██████████████████████████████████████████████████████▉                                                                             | 99/238 [00:37<00:42,  3.27it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2567178368.0 1877952.0
cuda:0


 42%|███████████████████████████████████████████████████████                                                                            | 100/238 [00:37<00:43,  3.19it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2567907968.0 1877952.0
cuda:0


 42%|███████████████████████████████████████████████████████▌                                                                           | 101/238 [00:37<00:44,  3.08it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2571945216.0 1877952.0
cuda:0


 43%|████████████████████████████████████████████████████████▏                                                                          | 102/238 [00:37<00:43,  3.15it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2678917376.0 1877952.0
cuda:0


 43%|████████████████████████████████████████████████████████▋                                                                          | 103/238 [00:38<00:42,  3.19it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2599241472.0 1877952.0
cuda:0


 44%|█████████████████████████████████████████████████████████▏                                                                         | 104/238 [00:38<00:42,  3.18it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2626154496.0 1877952.0
cuda:0


 44%|█████████████████████████████████████████████████████████▊                                                                         | 105/238 [00:38<00:41,  3.24it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2586112512.0 1877952.0
cuda:0


 45%|██████████████████████████████████████████████████████████▎                                                                        | 106/238 [00:39<00:40,  3.23it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2585254144.0 1877952.0
cuda:0


 45%|██████████████████████████████████████████████████████████▉                                                                        | 107/238 [00:39<00:41,  3.15it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2576801664.0 1877952.0
cuda:0


 45%|███████████████████████████████████████████████████████████▍                                                                       | 108/238 [00:39<00:41,  3.11it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2614752000.0 1877952.0
cuda:0


 46%|███████████████████████████████████████████████████████████▉                                                                       | 109/238 [00:40<00:41,  3.09it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2541174144.0 1877952.0
cuda:0


 46%|████████████████████████████████████████████████████████████▌                                                                      | 110/238 [00:40<00:42,  3.03it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2623665280.0 1877952.0
cuda:0


 47%|█████████████████████████████████████████████████████████████                                                                      | 111/238 [00:40<00:42,  2.96it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2609549952.0 1877952.0
cuda:0


 47%|█████████████████████████████████████████████████████████████▋                                                                     | 112/238 [00:41<00:41,  3.03it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2612154368.0 1877952.0
cuda:0


 47%|██████████████████████████████████████████████████████████████▏                                                                    | 113/238 [00:41<00:41,  3.04it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2553568256.0 1877952.0
cuda:0


 48%|██████████████████████████████████████████████████████████████▋                                                                    | 114/238 [00:41<00:38,  3.18it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2510384256.0 1877952.0
cuda:0


 48%|███████████████████████████████████████████████████████████████▎                                                                   | 115/238 [00:42<00:38,  3.22it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2634523392.0 1877952.0
cuda:0


 49%|███████████████████████████████████████████████████████████████▊                                                                   | 116/238 [00:42<00:37,  3.22it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2581114496.0 1877952.0
cuda:0


 49%|████████████████████████████████████████████████████████████████▍                                                                  | 117/238 [00:42<00:38,  3.18it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2620658048.0 1877952.0
cuda:0


 50%|████████████████████████████████████████████████████████████████▉                                                                  | 118/238 [00:43<00:38,  3.12it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2649349504.0 1877952.0
cuda:0


 50%|█████████████████████████████████████████████████████████████████▌                                                                 | 119/238 [00:43<00:39,  3.05it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2657469184.0 1877952.0
cuda:0


 50%|██████████████████████████████████████████████████████████████████                                                                 | 120/238 [00:43<00:39,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2597783040.0 1877952.0
cuda:0


 51%|██████████████████████████████████████████████████████████████████▌                                                                | 121/238 [00:44<00:39,  2.99it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2563805952.0 1877952.0
cuda:0


 51%|███████████████████████████████████████████████████████████████████▏                                                               | 122/238 [00:44<00:38,  3.00it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2636865792.0 1877952.0
cuda:0


 52%|███████████████████████████████████████████████████████████████████▋                                                               | 123/238 [00:44<00:39,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2581920896.0 1877952.0
cuda:0


 52%|████████████████████████████████████████████████████████████████████▎                                                              | 124/238 [00:45<00:39,  2.92it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2616632832.0 1877952.0
cuda:0


 53%|████████████████████████████████████████████████████████████████████▊                                                              | 125/238 [00:45<00:38,  2.96it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2531020032.0 1877952.0
cuda:0


 53%|█████████████████████████████████████████████████████████████████████▎                                                             | 126/238 [00:45<00:37,  2.98it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2500735360.0 1877952.0
cuda:0


 53%|█████████████████████████████████████████████████████████████████████▉                                                             | 127/238 [00:46<00:36,  3.05it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2570774400.0 1877952.0
cuda:0


 54%|██████████████████████████████████████████████████████████████████████▍                                                            | 128/238 [00:46<00:36,  3.02it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2578753280.0 1877952.0
cuda:0


 54%|███████████████████████████████████████████████████████████████████████                                                            | 129/238 [00:46<00:37,  2.91it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2646911488.0 1877952.0
cuda:0


 55%|███████████████████████████████████████████████████████████████████████▌                                                           | 130/238 [00:47<00:36,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2577659264.0 1877952.0
cuda:0


 55%|████████████████████████████████████████████████████████████████████████                                                           | 131/238 [00:47<00:36,  2.91it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2539562112.0 1877952.0
cuda:0


 55%|████████████████████████████████████████████████████████████████████████▋                                                          | 132/238 [00:47<00:36,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2556774272.0 1877952.0
cuda:0


 56%|█████████████████████████████████████████████████████████████████████████▏                                                         | 133/238 [00:48<00:36,  2.85it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2525145600.0 1877952.0
cuda:0


 56%|█████████████████████████████████████████████████████████████████████████▊                                                         | 134/238 [00:48<00:35,  2.91it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2564043136.0 1877952.0
cuda:0


 57%|██████████████████████████████████████████████████████████████████████████▎                                                        | 135/238 [00:48<00:35,  2.89it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2560293504.0 1877952.0
cuda:0


 57%|██████████████████████████████████████████████████████████████████████████▊                                                        | 136/238 [00:49<00:34,  2.92it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2574152832.0 1877952.0
cuda:0


 58%|███████████████████████████████████████████████████████████████████████████▍                                                       | 137/238 [00:49<00:34,  2.93it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2670708096.0 1877952.0
cuda:0


 58%|███████████████████████████████████████████████████████████████████████████▉                                                       | 138/238 [00:49<00:33,  2.97it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2623114880.0 1877952.0
cuda:0


 58%|████████████████████████████████████████████████████████████████████████████▌                                                      | 139/238 [00:50<00:34,  2.89it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2609325568.0 1877952.0
cuda:0


 59%|█████████████████████████████████████████████████████████████████████████████                                                      | 140/238 [00:50<00:34,  2.85it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2550330624.0 1877952.0
cuda:0


 59%|█████████████████████████████████████████████████████████████████████████████▌                                                     | 141/238 [00:50<00:33,  2.94it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2646949888.0 1877952.0
cuda:0


 60%|██████████████████████████████████████████████████████████████████████████████▏                                                    | 142/238 [00:51<00:32,  2.95it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
2604437504.0 1877952.0
cuda:0


 60%|██████████████████████████████████████████████████████████████████████████████▏                                                    | 142/238 [00:51<00:34,  2.76it/s]


KeyboardInterrupt: 