Should model.state <- dsharp.load(modelFile) change which device the model is on? #427

Open
@nhirschey

Description

I built the dev branch locally to test the new save/load functionality from PR #425 and I found something unexpected. Should model.state <- dsharp.load(modelFile) change which device the model is on?

When running rnn.fsx with #r "nuget: TorchSharp-cuda-windows, 0.96.0", restoring model state seems to always move the model to the CPU device, which causes an error unless I manually move it back to the GPU. Details below.

Related PyTorch docs here.

Edited with simpler self-contained repro:

// Assumes you're in the DiffSharp/examples directory
#I "../tests/DiffSharp.Tests/bin/Debug/net6.0"
#r "DiffSharp.Core.dll"
#r "DiffSharp.Data.dll"
#r "DiffSharp.Backends.Torch.dll"

#r "nuget: TorchSharp-cuda-windows, 0.96.0"


open DiffSharp
open DiffSharp.Compose
open DiffSharp.Model
open DiffSharp.Data
open DiffSharp.Optim
open DiffSharp.Util
open DiffSharp.Distributions

open System.IO

dsharp.config(backend=Backend.Torch, device=Device.GPU)
dsharp.seed(1)

let tempFile = Path.GetTempFileName()

let rnn = RNN(20,20,numLayers=1)

dsharp.save(rnn.state, tempFile)

rnn.device // val it: Device = Device (CUDA, 0)

rnn.state <- dsharp.load(tempFile)

rnn.device // val it: Device = Device (CPU, -1)
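As a stopgap, the workaround I'm using is to move the model back to the GPU after restoring its state. This is only a sketch: it assumes the model type exposes a move method accepting a Device (which is what I see on the dev branch, but the exact signature may differ):

```fsharp
// Hypothetical workaround sketch, continuing the repro above.
// Assumes the model exposes a `move` method taking a Device; signature
// may differ on the dev branch.
rnn.state <- dsharp.load(tempFile)   // state restored, model lands on CPU
rnn.move(Device.GPU)                 // move it back to the GPU manually
rnn.device                           // expected: Device (CUDA, 0)
```

This avoids the device-mismatch error when training resumes, but ideally load would either preserve the model's current device or accept a target device, similar to PyTorch's map_location argument.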
