In [1]:
:set -XPackageImports
:set -XOverloadedStrings

## Model Inference

Here we show how to load a trained model and use it for inference on raw data.

In [2]:
import Prelude hiding (round)
import Graphics.Vega.VegaLite
import qualified Torch as T
import Data.Frame as DF
import "prehsept" Lib

Models trained with `prehsept` can be loaded as TorchScript modules:

In [3]:
-- Selects which trained model to load. The active timestamp is the NMOS
-- training run; the commented-out one is the PMOS run (set dev accordingly).
pdk       = GPDK180
dev       = NMOS
timeStamp = "20220902-092231" -- NMOS
-- timeStamp = "20220720-135609" -- PMOS
-- Resulting path: ../models/<pdk>/<dev>-<timeStamp>/trace.pt
modelPath = concat [ "../models/", show pdk, "/", show dev, "-", timeStamp
                   , "/trace.pt" ]

In [4]:
model' <- T.loadScript T.WithoutRequiredGrad modelPath

-- Run the loaded TorchScript module on a single input tensor and return its
-- output tensor. If the module yields anything other than a tensor IValue,
-- fail with a descriptive message instead of an opaque pattern-match error.
model x = case T.forward model' [T.IVTensor x] of
  T.IVTensor y -> y
  _            -> error "model: expected the TorchScript module to return a single tensor"



For comparison, we also load the training data:

In [5]:
-- Training data for the same PDK/device: ../data/<pdk>-<dev>.pt
df <- DF.fromFile (concat ["../data/", show pdk, "-", show dev, ".pt"])



We'll choose a random width ($W$) and length ($L$) from the dataset:

In [6]:
-- Distinct values of a data-frame column along dim 0. Only the first
-- component of uniqueDim's result is kept; the three False flags are
-- presumably sorted / return-inverse / return-counts (TODO confirm).
uniqueAlongRows col = let (vals, _, _) = T.uniqueDim 0 False False False (df ?? col)
                      in vals

ws = uniqueAlongRows "W"
ls = uniqueAlongRows "L"

In [7]:
-- Draw one index per candidate list *uniformly*. multinomial samples with
-- probability proportional to its weights, so the weights must be equal:
-- weights 0..n-1 would never pick index 0, would favour later indices, and
-- would fail outright (all-zero weights) when there is a single candidate.
choicesW <- T.multinomialIO (T.ones' [head $ T.shape ws]) 1 False
w = T.squeezeAll $ T.indexSelect 0 choicesW ws

choicesL <- T.multinomialIO (T.ones' [head $ T.shape ls]) 1 False
l = T.squeezeAll $ T.indexSelect 0 choicesL ls

-- Largest vds in the training data, used as vdd (presumably the supply
-- voltage the sweeps were run at — TODO confirm).
vdd = T.squeezeAll . fst . T.maxDim (T.Dim 0) T.RemoveDim $ df ?? "vds"



In [8]:
-- Terminal-voltage columns passed through Lib's `round 2` (presumably
-- rounding to 2 decimal places) so exact equality filters on them are
-- predictable.
roundedColumn col = round 2 (df ?? col)

vbs = roundedColumn "vbs"
vgs = roundedColumn "vgs"
vds = roundedColumn "vds"

In [9]:
-- Operating-point slice of the training data for the sampled geometry:
-- rows with W == w, L == l, vbs == 0 and vds == vdd / 2, restricted to the
-- listed columns and sorted on "gmoverid" (the False flag is forwarded to
-- DF.sort; its ascending/descending meaning is defined there — TODO confirm).
traces = DF.sort False "gmoverid"
       . DF.lookup [ "gmoverid", "fug", "vds", "vbs"
                   , "id", "W", "L", "gm", "gds", "vgs" ]
       . DF.rowFilter ( T.logicalAnd ((df ?? "W") `T.eq` w)
                      . T.logicalAnd ((df ?? "L") `T.eq` l)
                      . T.logicalAnd (vbs `T.eq` 0.0)
                      -- vds went through `round 2`, so round the target the
                      -- same way: an exact float compare against the raw
                      -- vdd / 2 can otherwise match no rows at all.
                      $ (vds `T.eq` round 2 (vdd / 2.0)))
       $ df

Next, we extract the model inputs from the trace as a Torch tensor:

In [10]:
-- Model inputs: the gmoverid, fug, vds and vbs columns of the trace, in that
-- order, as one tensor.
x = DF.values $ DF.lookup ["gmoverid", "fug", "vds", "vbs"] traces

... and feed them through the model:

In [11]:
-- Model output for the trace inputs x (one forward pass through the module).
y = model x

Now we extract the results and compare them with the original data:

In [12]:
-- Small accessors used to pull 1-D series out of the trace and the model
-- tensors.
-- Column i of a tensor (selected along dim 1), squeezed.
colAt i t = T.squeezeAll $ T.indexSelect' 1 [i] t
-- A trace column, squeezed.
traceCol name = T.squeezeAll $ traces ?? name
-- A trace column divided element-wise by the trace's W column, squeezed.
perWidth name = T.squeezeAll $ (traces ?? name) / (traces ?? "W")

-- Observed values from the training data.
trueGmOverId = traceCol "gmoverid"
trueIdOverW  = perWidth "id"
trueGdsOverW = perWidth "gds"
trueL        = traceCol "L"
trueVgs      = traceCol "vgs"

-- NOTE: despite its name, predGmOverId is column 0 of the model *input* x,
-- not a model output.
predGmOverId = colAt 0 x
-- Model output columns as used here: 0 -> id/W, 1 -> L, 2 -> gds/W,
-- 3 -> vgs (presumably the training target order — TODO confirm).
predIdOverW  = colAt 0 y
predL'       = colAt 1 y
predGdsOverW = colAt 2 y
predVgs      = colAt 3 y

-- Statistics of the predicted L and the L range of the whole dataset
-- (stdL, hiL and loL are not used in the cells shown here).
meanL        = T.mean predL'
stdL         = T.std  predL'
hiL          = T.max $ df ?? "L"
loL          = T.min $ df ?? "L"

-- L series that gets plotted: a tensor shaped like trueL, filled with the
-- mean predicted L.
meanL'       = T.asValue meanL
predL        <- T.fullLike trueL meanL' T.defaultOpts



In [13]:
-- Convert the x values and the observed / predicted y values to Double
-- lists and pair them as (x, (observed, predicted)), truncated to the
-- shortest input.
plotData :: T.Tensor -> T.Tensor -> T.Tensor -> [(Double, (Double, Double))]
plotData xs observed predicted =
    zipWith3 point (doubles xs) (doubles observed) (doubles predicted)
  where
    point a b c = (a, (b, c))
    doubles :: T.Tensor -> [Double]
    doubles = T.asValue . T.toDType T.Double

-- Turn (x, (observed, predicted)) points into an hvega data source. Every
-- point contributes two rows, tagged "Observation" and "Prediction" in the
-- "Lines" field, in input order. The result still takes a list of extra
-- rows to append (callers pass []).
plt xAx yAx points = dataFromRows [] . foldr addPoint id points
  where
    addPoint (gx, (yObs, yPred)) rest =
        dataRow [ (xAx, Number gx)
                , (yAx, Number yObs)
                , ("Lines", Str "Observation") ]
      . dataRow [ (xAx, Number gx)
                , (yAx, Number yPred)
                , ("Lines", Str "Prediction") ]
      . rest

In [14]:
-- Every series is plotted against the observed gm/Id.
againstGmOverId = plotData trueGmOverId

idPlot  = againstGmOverId trueIdOverW  predIdOverW
gdsPlot = againstGmOverId trueGdsOverW predGdsOverW
lPlot   = againstGmOverId trueL        predL

In [15]:
-- x-axis option that places a tick at every x value of the given points.
ticksAt pts = PAxis [AxValues (Numbers (map fst pts))]

-- Encoding shared by the three charts: gm/Id on a quantitative x axis, the
-- named quantity on a log-scaled quantitative y axis, and colour split by
-- the nominal "Lines" field.
gmIdEncoding xAxis yName = encoding
  . position X [ PName "gm/Id in 1/V", PmType Quantitative, xAxis ]
  . position Y [ PName yName, PmType Quantitative, PScale [SType ScLog] ]
  . color [ MName "Lines", MmType Nominal ]

ax1  = ticksAt idPlot
enc1 = gmIdEncoding ax1 "Id/W in A/m"

ax2  = ticksAt gdsPlot
enc2 = gmIdEncoding ax2 "gds/W in S/m"

ax3  = ticksAt lPlot
enc3 = gmIdEncoding ax3 "L in m"

In [16]:
-- All three data sources share the gm/Id x field.
gmIdData yField pts = plt "gm/Id in 1/V" yField pts

idDat  = gmIdData "Id/W in A/m"  idPlot
gdsDat = gmIdData "gds/W in S/m" gdsPlot
lDat   = gmIdData "L in m"       lPlot

In [17]:
-- Display three 300x200 line charts comparing observation and prediction
-- against gm/Id: Id/W, gds/W and L (series distinguished by the "Lines"
-- colour encoding).
toVegaLite [ idDat [] 
           , mark Line []
           , enc1 []
           , height 200
           , width 300 ]

toVegaLite [ gdsDat [] 
           , mark Line []
           , enc2 []
           , height 200
           , width 300 ]

toVegaLite [ lDat [] 
           , mark Line []
           , enc3 []
           , height 200
           , width 300 ]