In [1]:
import torch
import yaml
import ast
#from dig.threedgraph.dataset import QM93D
from dig.threedgraph.dataset import MD17
from dig.threedgraph.dataset.PygTobermorite import Tobermorite
from dig.threedgraph.method import SphereNet #SchNet, DimeNetPP, ComENet
from dig.threedgraph.method import run
from dig.threedgraph.evaluation import ThreeDEvaluator

# Load configs

In [2]:
with open('config.yaml', 'r') as c:
    config = yaml.safe_load(c)
    
# For strings that yaml doesn't parse (e.g. None)
for key, val in config.items():
    if type(val) is str:
        try:
            config[key] = ast.literal_eval(val)
        except (ValueError, SyntaxError):
            pass

In [3]:
name = config['name']
n_train = config['n_train']
n_val = config['n_val']
seed = config['seed']
energy_and_force = config['energy_and_force']
cutoff = config['cutoff']
num_layers = config['num_layers']
hidden_channels = config['hidden_channels']
out_channels = config['out_channels']
int_emb_size = config['int_emb_size']
basis_emb_size_dist = config['basis_emb_size_dist']
basis_emb_size_angle = config['basis_emb_size_angle']
basis_emb_size_torsion = config['basis_emb_size_torsion']
out_emb_channels = config['out_emb_channels']
num_spherical = config['num_spherical']
num_radial = config['num_radial']
envelope_exponent = config['envelope_exponent']
num_before_skip = config['num_before_skip']
num_after_skip = config['num_after_skip']
num_output_layers = config['num_output_layers']

epochs = config['epochs']
batch_size = config['batch_size']
vt_batch_size = config['vt_batch_size']
lr = config['lr']
lr_decay_factor = config['lr_decay_factor']
lr_decay_step_size = config['lr_decay_step_size']

In [10]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

# Dataset

In [5]:
dataset = Tobermorite(root='dataset/', name=name)

split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=n_train, valid_size=n_val, seed=seed)

train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
print('train, validaion, test:', len(train_dataset), len(valid_dataset), len(test_dataset))

train, validaion, test: 950 50 6000




# Model

In [6]:
model = SphereNet(energy_and_force=energy_and_force, cutoff=cutoff, num_layers=num_layers, 
        hidden_channels=hidden_channels, out_channels=out_channels, int_emb_size=int_emb_size, 
        basis_emb_size_dist=basis_emb_size_dist, basis_emb_size_angle=basis_emb_size_angle, 
        basis_emb_size_torsion=basis_emb_size_torsion, out_emb_channels=out_emb_channels, 
        num_spherical=num_spherical, num_radial=num_radial, envelope_exponent=envelope_exponent, 
        num_before_skip=num_before_skip, num_after_skip=num_after_skip, num_output_layers=num_output_layers 
        )
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

# FLOPs

In [17]:
from torch_geometric.data import DataLoader
from tqdm import tqdm
from thop import profile

train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
for step, batch_data in enumerate(tqdm(train_loader)):
    print(device)
    batch_data = batch_data.to(device)
    inputs = batch_data
    model = model.to(device)
    flops,params=profile(model,(inputs, ))
    print(flops,params)

  0%|                                                                                                                                             | 0/238 [00:00<?, ?it/s]

cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


  1%|█▋                                                                                                                                   | 3/238 [00:00<00:55,  4.23it/s]

30319612288.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30383301760.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31390812544.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30620603776.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


  3%|███▉                                                                                                                                 | 7/238 [00:01<00:24,  9.42it/s]

30761649408.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30578639616.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31576157312.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31163682816.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


  5%|██████                                                                                                                              | 11/238 [00:01<00:17, 13.22it/s]

31107071744.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31606167936.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30738311296.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30250420480.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


  6%|████████▎                                                                                                                           | 15/238 [00:01<00:14, 15.60it/s]

30560319104.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30342233600.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31223967488.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30786415744.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


  8%|██████████▌                                                                                                                         | 19/238 [00:01<00:13, 16.67it/s]

30015901440.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29804082560.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30297528576.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31103213568.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 10%|████████████▊                                                                                                                       | 23/238 [00:01<00:12, 17.46it/s]

30743913984.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30052394624.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30659126400.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31134879104.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 11%|██████████████▉                                                                                                                     | 27/238 [00:02<00:11, 17.79it/s]

30887415552.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30622390400.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30146679808.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30257672704.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 13%|█████████████████▏                                                                                                                  | 31/238 [00:02<00:11, 18.00it/s]

31535579264.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30542936704.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31825196928.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30952633600.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 15%|███████████████████▍                                                                                                                | 35/238 [00:02<00:11, 18.13it/s]

30328793600.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30963338112.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30790895744.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31482140928.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 16%|█████████████████████▋                                                                                                              | 39/238 [00:02<00:10, 18.15it/s]

31011400448.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31007147136.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30675786624.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30993933824.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 18%|███████████████████████▊                                                                                                            | 43/238 [00:03<00:10, 18.12it/s]

30368075136.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31163603968.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30881501952.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31265441536.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 20%|██████████████████████████                                                                                                          | 47/238 [00:03<00:10, 18.18it/s]

31829349888.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30925437312.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31152983680.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30514138368.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 21%|████████████████████████████▎                                                                                                       | 51/238 [00:03<00:10, 18.13it/s]

30872958592.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31361070720.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30872326016.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31454143616.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 23%|██████████████████████████████▌                                                                                                     | 55/238 [00:03<00:10, 18.17it/s]

30962848000.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30336093312.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30637179776.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29910679680.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 25%|████████████████████████████████▋                                                                                                   | 59/238 [00:03<00:09, 18.13it/s]

30304696576.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30505984768.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30664718336.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31014578560.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 26%|██████████████████████████████████▉                                                                                                 | 63/238 [00:04<00:09, 18.11it/s]

30078194048.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30656707200.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31133946368.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30788972032.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 28%|█████████████████████████████████████▏                                                                                              | 67/238 [00:04<00:09, 18.14it/s]

30091992448.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30893824640.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30895927552.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29939615104.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 30%|███████████████████████████████████████▍                                                                                            | 71/238 [00:04<00:09, 18.19it/s]

31414103168.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30495058944.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30678110848.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31071637632.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 32%|█████████████████████████████████████████▌                                                                                          | 75/238 [00:04<00:08, 18.24it/s]

30802090368.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31587489024.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30211987456.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31275924736.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 33%|███████████████████████████████████████████▊                                                                                        | 79/238 [00:05<00:08, 18.05it/s]

30741584384.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30461696384.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31336346496.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30570707328.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 35%|██████████████████████████████████████████████                                                                                      | 83/238 [00:05<00:08, 18.13it/s]

31139454080.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30749463808.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31007942784.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31123916544.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 37%|████████████████████████████████████████████████▎                                                                                   | 87/238 [00:05<00:08, 18.19it/s]

30986760448.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31246045824.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31109353856.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30648016000.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 38%|██████████████████████████████████████████████████▍                                                                                 | 91/238 [00:05<00:08, 18.20it/s]

30939293952.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30812889856.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29549266432.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31888838912.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 40%|████████████████████████████████████████████████████▋                                                                               | 95/238 [00:05<00:07, 18.18it/s]

30490710656.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30373941248.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30573126528.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30374442112.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 42%|██████████████████████████████████████████████████████▉                                                                             | 99/238 [00:06<00:07, 18.14it/s]

29994613376.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30346576512.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30389125760.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30079812224.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 43%|████████████████████████████████████████████████████████▋                                                                          | 103/238 [00:06<00:07, 18.12it/s]

31087317632.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30340172800.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30663374336.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30141119232.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 45%|██████████████████████████████████████████████████████████▉                                                                        | 107/238 [00:06<00:07, 18.01it/s]

30256012416.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31875314688.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31024028672.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29514691584.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 47%|█████████████████████████████████████████████████████████████                                                                      | 111/238 [00:06<00:07, 17.94it/s]

31791865728.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31195174528.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31161532416.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30984214912.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 48%|███████████████████████████████████████████████████████████████▎                                                                   | 115/238 [00:07<00:06, 18.05it/s]

30680577536.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31073160832.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30588400640.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30641475200.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 50%|█████████████████████████████████████████████████████████████████▌                                                                 | 119/238 [00:07<00:06, 18.15it/s]

31008628224.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30346576512.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31633801472.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30292421376.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 52%|███████████████████████████████████████████████████████████████████▋                                                               | 123/238 [00:07<00:06, 18.17it/s]

30799054720.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30093652736.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30557583616.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30997912960.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 53%|█████████████████████████████████████████████████████████████████████▉                                                             | 127/238 [00:07<00:06, 18.16it/s]

30763214720.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30065512960.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30869147904.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31164220416.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 55%|████████████████████████████████████████████████████████████████████████                                                           | 131/238 [00:07<00:05, 18.20it/s]

31295541760.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31380424320.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30493672832.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30769033344.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 57%|██████████████████████████████████████████████████████████████████████████▎                                                        | 135/238 [00:08<00:05, 18.22it/s]

30779037184.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30410988160.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31201578240.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29611121792.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 58%|████████████████████████████████████████████████████████████████████████████▌                                                      | 139/238 [00:08<00:05, 18.20it/s]

31148545792.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29850437120.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30171535744.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31515735552.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 60%|██████████████████████████████████████████████████████████████████████████████▋                                                    | 143/238 [00:08<00:05, 18.04it/s]

31159914240.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30623518464.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30321088000.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30633369088.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 62%|████████████████████████████████████████████████████████████████████████████████▉                                                  | 147/238 [00:08<00:05, 18.16it/s]

30722720896.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31396104320.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30845498880.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30758197120.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 63%|███████████████████████████████████████████████████████████████████████████████████                                                | 151/238 [00:09<00:04, 18.15it/s]

29803897984.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30793894656.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29853304320.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29394970752.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 65%|█████████████████████████████████████████████████████████████████████████████████████▎                                             | 155/238 [00:09<00:04, 18.20it/s]

31051488384.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30251532416.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31162117504.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30941570688.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 67%|███████████████████████████████████████████████████████████████████████████████████████▌                                           | 159/238 [00:09<00:04, 18.17it/s]

31407789056.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
32123047040.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30196038656.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30019437952.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 68%|█████████████████████████████████████████████████████████████████████████████████████████▋                                         | 163/238 [00:09<00:04, 18.14it/s]

30848671616.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30337980288.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30655942912.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31016771072.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 70%|███████████████████████████████████████████████████████████████████████████████████████████▉                                       | 167/238 [00:09<00:03, 18.20it/s]

30101938048.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31015564160.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31253572224.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31211434240.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 72%|██████████████████████████████████████████████████████████████████████████████████████████████                                     | 171/238 [00:10<00:03, 18.19it/s]

31124222080.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30473512832.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29932041216.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30324540288.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 74%|████████████████████████████████████████████████████████████████████████████████████████████████▎                                  | 175/238 [00:10<00:03, 18.11it/s]

30346486912.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29531536384.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29912877568.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30691956736.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 75%|██████████████████████████████████████████████████████████████████████████████████████████████████▌                                | 179/238 [00:10<00:03, 18.16it/s]

30176326656.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30077477248.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30910311040.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30983271424.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 77%|████████████████████████████████████████████████████████████████████████████████████████████████████▋                              | 183/238 [00:10<00:03, 18.12it/s]

30562464128.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31055557120.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30844160256.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30895073664.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▊                             | 185/238 [00:10<00:02, 18.08it/s]

31363173632.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30334438400.0 1877952.0
cuda:0


 79%|████████████████████████████████████████████████████████████████████████████████████████████████████████                           | 189/238 [00:11<00:03, 14.23it/s]

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30394638848.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31149752704.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30610531840.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▏                        | 193/238 [00:11<00:02, 16.03it/s]

31423374080.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30756547584.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30531246592.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30614785152.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                      | 197/238 [00:11<00:02, 17.00it/s]

30886208640.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30705170048.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30838104192.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31625552896.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                    | 201/238 [00:11<00:02, 17.44it/s]

30890820352.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30963833600.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30486641920.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30893376640.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                  | 205/238 [00:12<00:01, 17.71it/s]

30878592640.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30485566720.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30167725056.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30957424512.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████                | 209/238 [00:12<00:01, 17.93it/s]

30753053184.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30607880576.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29957756416.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30902241664.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 89%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏             | 213/238 [00:12<00:01, 18.04it/s]

30374305024.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30625837312.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29608170368.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31070968320.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍           | 217/238 [00:12<00:01, 18.09it/s]

30808857856.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29865263232.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30163340032.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30995188224.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋         | 221/238 [00:13<00:00, 18.08it/s]

31425350656.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31583820800.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30482335744.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30545577216.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 95%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊       | 225/238 [00:13<00:00, 18.11it/s]

30641786112.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30038064000.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30952538624.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
31135511680.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████     | 229/238 [00:13<00:00, 18.20it/s]

30904302464.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30325979264.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30553772928.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30976235136.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏  | 233/238 [00:13<00:00, 18.12it/s]

30666563200.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30400952960.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30827942656.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30896955264.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍| 237/238 [00:13<00:00, 18.09it/s]

30352442624.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
29014520192.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30612455552.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
30208176768.0 1877952.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 238/238 [00:13<00:00, 17.02it/s]

15497767936.0 1877952.0



