
Should model.state <- dsharp.load(modelFile) change which device the model is on? #427

Open
nhirschey opened this issue Apr 26, 2022 · 0 comments

nhirschey (Contributor) commented Apr 26, 2022

I built the dev branch locally to test the new save/load functionality from PR #425 and I found something unexpected. Should model.state <- dsharp.load(modelFile) change which device the model is on?

When running rnn.fsx with #r "nuget: TorchSharp-cuda-windows, 0.96.0", restoring model state always seems to move the model to the CPU device, which causes an error unless I manually move it back to the GPU. Details below.

Related PyTorch docs here.

Edited with a simpler, self-contained repro:

// Assumes you're in the DiffSharp/examples directory
#I "../tests/DiffSharp.Tests/bin/Debug/net6.0"
#r "DiffSharp.Core.dll"
#r "DiffSharp.Data.dll"
#r "DiffSharp.Backends.Torch.dll"

#r "nuget: TorchSharp-cuda-windows, 0.96.0"


open DiffSharp
open DiffSharp.Compose
open DiffSharp.Model
open DiffSharp.Data
open DiffSharp.Optim
open DiffSharp.Util
open DiffSharp.Distributions

open System.IO

dsharp.config(backend=Backend.Torch, device=Device.GPU)
dsharp.seed(1)

let tempFile = Path.GetTempFileName()

let rnn = RNN(20,20,numLayers=1)

// Save the freshly created model's state; at this point the model is on the GPU.
dsharp.save(rnn.state, tempFile)

rnn.device // val it: Device = Device (CUDA, 0)

// Restore the state we just saved; unexpectedly, this moves the model to the CPU.
rnn.state <- dsharp.load(tempFile)

rnn.device // val it: Device = Device (CPU, -1)
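
For now I work around it by moving the model back after restoring the state; a minimal sketch, assuming models expose a move member analogous to Tensor.move:

// Workaround sketch, not a verified fix: relocate the model's parameters
// back to the GPU after loading the state (assumes Model has a move member).
rnn.move(Device.GPU)
rnn.device // expected: Device (CUDA, 0)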