In [1]:
import os
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import math
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import BinaryAUROC, BinaryAUPRC

os.chdir('../..')
from src.raindrop.raindrop import Raindrop
from src.raindrop.classifier import RaindropClassifier
from src.util.grad_track import GradientTracker, GradientFlowAnalyzer, pretty_flow
from src.p19.utils import *

In [4]:
def load_latest_model(models_dir: Path = Path('./models')) -> nn.Module:
    assert models_dir.exists()

    def model_name_key(name: str):
        return math.prod([int(v) for v in name.split('_')[1:]])

    recent_model_name = sorted(next(models_dir.walk())[2],
                               key=model_name_key)[-1]
        
    return torch.load(models_dir / recent_model_name, weights_only=True)

rd_cls = load_latest_model()
rd_cls

OrderedDict([('rd_model.obs_emb_weights',
              tensor([[[-1.0803e+00, -8.0069e-01,  2.4556e-01, -9.9541e-01]],
              
                      [[-1.4734e-01, -4.4649e-04,  5.9392e-01, -7.0453e-02]],
              
                      [[-7.2876e-01, -5.9000e-01,  1.0194e+00,  1.6161e-01]],
              
                      [[ 9.2568e-01,  5.1473e-01, -1.0914e+00, -9.7570e-01]],
              
                      [[-1.1516e+00,  4.7278e-01, -6.1093e-01,  1.0537e-01]],
              
                      [[ 7.6890e-01, -1.0950e+00, -1.1399e+00,  6.6586e-01]],
              
                      [[-1.6712e-01,  6.7675e-01, -4.5569e-01,  1.2524e-01]],
              
                      [[ 8.4744e-01, -7.2125e-01, -1.0334e+00, -2.0182e-01]],
              
                      [[ 4.2203e-01,  1.0669e+00, -5.7220e-02,  7.4093e-01]],
              
                      [[-1.1235e+00, -2.4892e-01, -5.7873e-01,  1.1228e+00]],
              
                      [[-4.6