This notebook shows how to load and use a ResNet pretrained for MVT Lead-Zinc Deposits

In [1]:
# imports
import torch
from torchinfo import summary
from sri_maper.src.models.cma_module import CMALitModule
import sys
if sys.version_info < (3, 9):
    from importlib_resources import files
else:
    from importlib.resources import files

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# resolves the checkpoint packaged under sri_maper.ckpts and restores the Lightning module from it
ckpt_path = str(files("sri_maper.ckpts").joinpath("epoch_007.ckpt"))
model = CMALitModule.load_from_checkpoint(checkpoint_path=ckpt_path)
# summarizes the wrapped network for a single 73-channel, 33x33 input patch
# (kept as the cell's last expression so the summary table is displayed)
summary(model.net, input_size=(1, 73, 33, 33))

  rank_zero_warn(


Layer (type:depth-idx)                        Output Shape              Param #
ResNet                                        [1, 1]                    --
├─FeatureListNet: 1-1                         [1, 512, 2, 2]            --
│    └─Conv2d: 2-1                            [1, 64, 17, 17]           228,928
│    └─BatchNorm2d: 2-2                       [1, 64, 17, 17]           128
│    └─ReLU: 2-3                              [1, 64, 17, 17]           --
│    └─MaxPool2d: 2-4                         [1, 64, 9, 9]             --
│    └─Sequential: 2-5                        [1, 64, 9, 9]             --
│    │    └─BasicBlock: 3-1                   [1, 64, 9, 9]             73,984
│    │    └─BasicBlock: 3-2                   [1, 64, 9, 9]             73,984
│    └─Sequential: 2-6                        [1, 128, 5, 5]            --
│    │    └─BasicBlock: 3-3                   [1, 128, 5, 5]            230,144
│    │    └─BasicBlock: 3-4                   [1, 128, 5, 5]            295,

In [3]:
# creates random sample data
# seeds the global RNG so the random demo inputs are reproducible across runs
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# batch of 2 random patches: (batch, channels, height, width) = (2, 73, 33, 33)
input_patch = torch.rand((2,73,33,33), device=device)
# random binary labels; torch.randint's `high` is EXCLUSIVE, so high=2 is needed to draw
# from {0, 1} (the previous high=1 could only ever produce 0)
label = torch.randint(low=0, high=2, size=(2,), device=device)
batch_idx = 0
locs = torch.tensor([[-93.19,-93],[32.44,32.44]], device=device) # [list of longs, list of lats]

You can use the CMA module in various ways by calling its methods. 

Please see the documentation within the methods for details. 

Below we provide examples of calling each.

In [4]:
# prints logits
# runs the forward pass once and prints that same tensor; the previous version called
# forward a second time inside the f-string, doubling the work and leaving `logits` unused
logits = model.forward(input_patch)
print(f"Logits:{logits}")

Logits:tensor([[-5.4023],
        [-5.6066]], device='cuda:0', grad_fn=<AddmmBackward0>)


In [5]:
# runs one model step on the (inputs, labels) pair and reports what it returns
step_batch = (input_patch, label)
loss, pred, target = model.model_step(step_batch)
# one line per returned value: loss, predictions, labels
print(f"Loss:{loss}", f"Pred:{pred}", f"Target:{target}", sep="\n")

Loss:0.16410422325134277
Pred:tensor([[0.],
        [0.]], device='cuda:0', dtype=torch.float16)
Target:tensor([0, 0], device='cuda:0')


In [6]:
# runs a prediction step on (inputs, labels, longitudes, latitudes) for this batch index
predict_batch = (input_patch, label, locs[0], locs[1])
output_tensor = model.predict_step(predict_batch, batch_idx)
# reports the first row only; per the labels below, columns 0-1 hold the location,
# columns 2-3 the likelihood and uncertainty, and the remaining columns the feature attributions
print(
    f"Long, Lat: {output_tensor[0,:2]}",
    f"Likelihood, Uncertainty: {output_tensor[0,2:4]}",
    f"Feature attributions: {output_tensor[0,4:]}",
    sep="\n",
)

Long, Lat: tensor([-93.1900,  32.4400], device='cuda:0', dtype=torch.float64,
       grad_fn=<SliceBackward0>)
Likelihood, Uncertainty: tensor([0.0182, 0.0440], device='cuda:0', dtype=torch.float64,
       grad_fn=<SliceBackward0>)
Feature attributions: tensor([ 3.0759e-05, -3.7590e-05, -7.3854e-07,  2.3355e-05, -1.2856e-05,
         2.0724e-06, -1.7814e-05,  1.0744e-05, -1.1759e-05, -1.5473e-05,
        -2.2104e-05, -9.2097e-06, -1.0176e-05, -1.9268e-05,  1.9260e-06,
        -4.1062e-06,  1.8495e-05, -1.6228e-05, -4.6928e-06, -3.3547e-05,
        -3.6940e-05, -2.3548e-05, -6.8299e-05, -1.0001e-05, -1.5050e-05,
        -2.2456e-05, -8.4889e-06,  3.4970e-06, -2.6028e-05, -3.3174e-05,
        -1.2171e-05, -3.2828e-05, -1.1060e-05, -6.8202e-05, -3.2495e-05,
        -1.9297e-06, -7.2174e-06, -1.7573e-05,  8.1906e-06,  1.7460e-05,
        -6.1063e-06, -5.0287e-05, -6.7967e-05,  6.5304e-06, -6.4605e-05,
        -1.3587e-05, -5.1910e-05, -2.9643e-06, -5.9390e-06, -2.4791e-05,
        -8.5277e