### Load models you have saved and inspect them
Takeaway: to load a saved model (state dict) you need the model class's code, including all the hyperparameters it was built with.
If you instead save a TorchScript version using `torch.jit.script(model)`, you do not need the class code, but you can only use the loaded model for inference (see the PyTorch "Saving and Loading Models" tutorial).

In [19]:
import torch
from torch import nn
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

In [13]:
# Checkpoint file to inspect. This is an absolute, machine-specific path -- change it
# when running the notebook anywhere else.
model_file = "/tigress/kendrab/analysis-notebooks/model_outs/05-07-23/samples/A135332_modelfile.tar"

In [14]:
# Hyperparameters. These must match the values used when the checkpoint was trained;
# ModelA reads several of them (out_channels, kernel_size, pool_size, stride).
stride = 10          # size of each segment, and therefore the spacing between segments
padding_length = 10  # samples of extra context taken on each side of a segment
input_length = 2 * padding_length + stride  # segment plus padding on both sides

kernel_size = 3       # Conv1d kernel width
pool_size = 2         # MaxPool1d window (its stride defaults to the same value)
out_channels = 16     # number of conv filters, like 'filters' in Keras
learning_rate = 1e-2  # Adam learning rate

In [15]:
def _hparam(value, name):
    """Return ``value``, or the module-level global ``name`` if ``value`` is None.

    This keeps ``ModelA()`` reading the notebook's hyperparameter globals at
    construction time (the original behavior) while allowing explicit overrides.

    Raises:
        NameError: if ``value`` is None and no global called ``name`` exists.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"hyperparameter {name!r} was not passed to ModelA and no global "
            f"{name!r} is defined"
        ) from None


def _conv_branch(out_channels, kernel_size, pool_size):
    """Return one single-channel input branch: Conv1d -> ReLU -> MaxPool1d."""
    return nn.Sequential(nn.Conv1d(1, out_channels, kernel_size, padding='valid'),
                         nn.ReLU(),
                         nn.MaxPool1d(pool_size))


class ModelA(nn.Module):
    """1D CNN with one convolutional branch per input variable.

    Each of the five inputs (bx, by, bz, jy, vz) goes through its own
    Conv1d -> ReLU -> MaxPool1d branch. The branch outputs are averaged and fed
    to a shared head: Conv1d -> ReLU -> MaxPool1d -> Flatten -> LazyLinear ->
    ReLU -> Unflatten.

    Any hyperparameter left as None is read from the notebook global of the
    same name when the model is constructed, so ``ModelA()`` behaves exactly
    as before. Pass the values explicitly to build the model without those
    globals.

    Args:
        out_channels: filters in each branch conv; the head conv uses twice this.
        kernel_size: kernel width of every Conv1d.
        pool_size: MaxPool1d window size.
        stride: length of each output row. The head's LazyLinear emits
            ``2 * stride`` features, which are reshaped to ``(2, stride)``.
    """

    def __init__(self, out_channels=None, kernel_size=None, pool_size=None, stride=None):
        super().__init__()
        out_channels = _hparam(out_channels, "out_channels")
        kernel_size = _hparam(kernel_size, "kernel_size")
        pool_size = _hparam(pool_size, "pool_size")
        stride = _hparam(stride, "stride")

        # Separate branches so each input gets its own weights.
        # Consider merging them into a single Conv1d with in_channels=5.
        # The branches are created in this fixed order so that seeded random
        # initialization matches earlier versions of this class.
        self.bx_layers = _conv_branch(out_channels, kernel_size, pool_size)
        self.by_layers = _conv_branch(out_channels, kernel_size, pool_size)
        self.bz_layers = _conv_branch(out_channels, kernel_size, pool_size)
        self.jy_layers = _conv_branch(out_channels, kernel_size, pool_size)
        self.vz_layers = _conv_branch(out_channels, kernel_size, pool_size)

        # TODO split this into CNN and classifier parts to facilitate domain adaptation
        # Layer indices (0-6) are part of the state-dict keys and of the graph node
        # names used with create_feature_extractor -- do not reorder.
        self.post_merge_layers = nn.Sequential(
            nn.Conv1d(out_channels, out_channels * 2, kernel_size, padding='valid'),
            nn.ReLU(),
            nn.MaxPool1d(pool_size),
            nn.Flatten(),
            # in_features is inferred lazily from the first input (or from a loaded state dict)
            nn.LazyLinear(stride * 2),
            # NOTE(review): this ReLU makes the returned "logits" non-negative --
            # confirm that is intended (changing it would alter trained checkpoints' outputs).
            nn.ReLU(),
            nn.Unflatten(1, (2, stride)),
        )

    def forward(self, bx, by, bz, jy, vz):
        """Run each input through its branch, average them, and apply the head.

        Each input feeds a Conv1d with in_channels=1, so it presumably has shape
        (batch, 1, length) -- TODO confirm against the data pipeline.

        Returns:
            Head output after Unflatten(1, (2, stride)), i.e. (batch, 2, stride)
            for batched input.
        """
        bx_proc = self.bx_layers(bx)
        by_proc = self.by_layers(by)
        bz_proc = self.bz_layers(bz)
        jy_proc = self.jy_layers(jy)
        vz_proc = self.vz_layers(vz)
        # Mean of the five branch outputs (0.2 == 1/5).
        combined = .2*(bx_proc + by_proc + bz_proc + jy_proc + vz_proc)
        logits = self.post_merge_layers(combined)

        return logits


In [17]:
model = ModelA()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# map_location="cpu": the model above is built on the CPU, so load the saved tensors
# there too. This also works when the checkpoint was written from a GPU run.
# weights_only=False: the checkpoint stores more than tensors (e.g. 'loss_fn', presumably
# a pickled loss object), which torch >= 2.6 refuses to unpickle by default.
# Only do this for checkpoints you produced yourself -- unpickling can execute code.
checkpoint = torch.load(model_file, map_location="cpu", weights_only=False)
# Loading the state dict also initializes the LazyLinear's parameters from the saved shapes.
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss_fn = checkpoint['loss_fn']

# Switch to inference mode. None of ModelA's layers behave differently in train vs
# eval mode, but this keeps the cell correct if dropout/batch-norm are added later.
# Left as the last expression so the notebook displays the model summary.
model.eval()

ModelA(
  (bx_layers): Sequential(
    (0): Conv1d(1, 16, kernel_size=(3,), stride=(1,), padding=valid)
    (1): ReLU()
    (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (by_layers): Sequential(
    (0): Conv1d(1, 16, kernel_size=(3,), stride=(1,), padding=valid)
    (1): ReLU()
    (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (bz_layers): Sequential(
    (0): Conv1d(1, 16, kernel_size=(3,), stride=(1,), padding=valid)
    (1): ReLU()
    (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (jy_layers): Sequential(
    (0): Conv1d(1, 16, kernel_size=(3,), stride=(1,), padding=valid)
    (1): ReLU()
    (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (vz_layers): Sequential(
    (0): Conv1d(1, 16, kernel_size=(3,), stride=(1,), padding=valid)
    (1): ReLU()
    (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=Fal

In [18]:
# Show the traceable node names of `model` as a (train-mode names, eval-mode names) pair.
# These names can be passed to create_feature_extractor(return_nodes=...).
get_graph_node_names(model)

(['bx',
  'by',
  'bz',
  'jy',
  'vz',
  'bx_layers.0',
  'bx_layers.1',
  'bx_layers.2',
  'by_layers.0',
  'by_layers.1',
  'by_layers.2',
  'bz_layers.0',
  'bz_layers.1',
  'bz_layers.2',
  'jy_layers.0',
  'jy_layers.1',
  'jy_layers.2',
  'vz_layers.0',
  'vz_layers.1',
  'vz_layers.2',
  'add',
  'add_1',
  'add_2',
  'add_3',
  'mul',
  'post_merge_layers.0',
  'post_merge_layers.1',
  'post_merge_layers.2',
  'post_merge_layers.3',
  'post_merge_layers.4',
  'post_merge_layers.5',
  'post_merge_layers.6'],
 ['bx',
  'by',
  'bz',
  'jy',
  'vz',
  'bx_layers.0',
  'bx_layers.1',
  'bx_layers.2',
  'by_layers.0',
  'by_layers.1',
  'by_layers.2',
  'bz_layers.0',
  'bz_layers.1',
  'bz_layers.2',
  'jy_layers.0',
  'jy_layers.1',
  'jy_layers.2',
  'vz_layers.0',
  'vz_layers.1',
  'vz_layers.2',
  'add',
  'add_1',
  'add_2',
  'add_3',
  'mul',
  'post_merge_layers.0',
  'post_merge_layers.1',
  'post_merge_layers.2',
  'post_merge_layers.3',
  'post_merge_layers.4',
  'post

### Try torchvision's feature extractor (`create_feature_extractor`)

In [21]:
# Build a truncated model that stops at post_merge_layers.3 (the nn.Flatten output,
# i.e. the flattened conv features fed to the LazyLinear head). The output is keyed
# by the same node name.
model_feats = create_feature_extractor(
    model,
    return_nodes={"post_merge_layers.3": "post_merge_layers.3"},
)
# Check that the extracted graph ends at the requested node.
get_graph_node_names(model_feats)

(['bx',
  'by',
  'bz',
  'jy',
  'vz',
  'bx_layers.0',
  'bx_layers.1',
  'bx_layers.2',
  'by_layers.0',
  'by_layers.1',
  'by_layers.2',
  'bz_layers.0',
  'bz_layers.1',
  'bz_layers.2',
  'jy_layers.0',
  'jy_layers.1',
  'jy_layers.2',
  'vz_layers.0',
  'vz_layers.1',
  'vz_layers.2',
  'add',
  'add_1',
  'add_2',
  'add_3',
  'mul',
  'post_merge_layers.0',
  'post_merge_layers.1',
  'post_merge_layers.2',
  'post_merge_layers.3'],
 ['bx',
  'by',
  'bz',
  'jy',
  'vz',
  'bx_layers.0',
  'bx_layers.1',
  'bx_layers.2',
  'by_layers.0',
  'by_layers.1',
  'by_layers.2',
  'bz_layers.0',
  'bz_layers.1',
  'bz_layers.2',
  'jy_layers.0',
  'jy_layers.1',
  'jy_layers.2',
  'vz_layers.0',
  'vz_layers.1',
  'vz_layers.2',
  'add',
  'add_1',
  'add_2',
  'add_3',
  'mul',
  'post_merge_layers.0',
  'post_merge_layers.1',
  'post_merge_layers.2',
  'post_merge_layers.3'])