In [51]:

import os
import torch
import logging
import torch.nn as nn
import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from enum import Enum
from functools import cached_property
from dataclasses import dataclass
from mubench.base.model import DNN
from mubench.base.dataset import Dataset
from mubench.base.args import Config
from mubench.chemberta.model import ChemBERTa
import random
np.random.seed(0)
random.seed(0)

logger = logging.getLogger(__name__)



In [55]:
config = Config()

In [58]:
model = DNN(128, 2, 1, apply_bbp=True)

In [59]:
model.state_dict().keys()

odict_keys(['input_layer.0.weight', 'input_layer.0.bias', 'hidden_layers.0.0.weight', 'hidden_layers.0.0.bias', 'hidden_layers.1.0.weight', 'hidden_layers.1.0.bias', 'hidden_layers.2.0.weight', 'hidden_layers.2.0.bias', 'hidden_layers.3.0.weight', 'hidden_layers.3.0.bias', 'output_layer.output_layer.weight_mu', 'output_layer.output_layer.weight_rho', 'output_layer.output_layer.bias_mu', 'output_layer.output_layer.bias_rho'])

In [60]:
list(model.named_parameters())

[('input_layer.0.weight',
  Parameter containing:
  tensor([[ 0.0197, -0.0965, -0.0576,  ..., -0.1482, -0.0119,  0.0767],
          [-0.1471, -0.1042, -0.1234,  ..., -0.1399,  0.1250, -0.0574],
          [-0.0682,  0.1503, -0.1108,  ..., -0.0971, -0.0103, -0.0390],
          ...,
          [ 0.0313,  0.0148, -0.0205,  ..., -0.0056, -0.0918, -0.0491],
          [-0.0985, -0.1314,  0.0602,  ..., -0.0949, -0.0166,  0.0448],
          [ 0.1162, -0.0953,  0.1159,  ...,  0.0381,  0.0094,  0.0522]],
         requires_grad=True)),
 ('input_layer.0.bias',
  Parameter containing:
  tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0

In [61]:
params = list(filter(lambda kv: "output_layer" not in kv[0], model.named_parameters()))

In [62]:
params

[('input_layer.0.weight',
  Parameter containing:
  tensor([[ 0.0197, -0.0965, -0.0576,  ..., -0.1482, -0.0119,  0.0767],
          [-0.1471, -0.1042, -0.1234,  ..., -0.1399,  0.1250, -0.0574],
          [-0.0682,  0.1503, -0.1108,  ..., -0.0971, -0.0103, -0.0390],
          ...,
          [ 0.0313,  0.0148, -0.0205,  ..., -0.0056, -0.0918, -0.0491],
          [-0.0985, -0.1314,  0.0602,  ..., -0.0949, -0.0166,  0.0448],
          [ 0.1162, -0.0953,  0.1159,  ...,  0.0381,  0.0094,  0.0522]],
         requires_grad=True)),
 ('input_layer.0.bias',
  Parameter containing:
  tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,
          0.0

In [64]:
type(params[0])

tuple

In [67]:
params[0][1]

Parameter containing:
tensor([[ 0.0197, -0.0965, -0.0576,  ..., -0.1482, -0.0119,  0.0767],
        [-0.1471, -0.1042, -0.1234,  ..., -0.1399,  0.1250, -0.0574],
        [-0.0682,  0.1503, -0.1108,  ..., -0.0971, -0.0103, -0.0390],
        ...,
        [ 0.0313,  0.0148, -0.0205,  ..., -0.0056, -0.0918, -0.0491],
        [-0.0985, -0.1314,  0.0602,  ..., -0.0949, -0.0166,  0.0448],
        [ 0.1162, -0.0953,  0.1159,  ...,  0.0381,  0.0094,  0.0522]],
       requires_grad=True)

In [68]:
type(params[0][1])

torch.nn.parameter.Parameter