-
Notifications
You must be signed in to change notification settings - Fork 370
/
Copy pathbasics.rs
28 lines (26 loc) · 812 Bytes
/
basics.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
use tch::{kind, Tensor};
/// Demonstrates reverse-mode differentiation on a scalar tensor.
///
/// Evaluates `y = x * x + x + 36` at `x = 2.0` and prints `y` (42 for this input).
/// It then back-propagates from `y` and prints the gradient stored on `x`, which
/// mathematically is `dy/dx = 2x + 1` (5 for this input).
fn grad_example() {
    // Leaf tensor that autograd tracks.
    let mut input = Tensor::from(2.0).set_requires_grad(true);
    let output = &input * &input + &input + 36;

    let output_value = output.double_value(&[]);
    println!("{}", output_value);

    // Reset the leaf's gradient before back-propagating. On this freshly created
    // tensor there should be nothing accumulated yet; this mirrors the usual
    // training-loop pattern.
    input.zero_grad();
    output.backward();

    let gradient_value = input.grad().double_value(&[]);
    println!("{}", gradient_value);
}
/// Walks through a few basic tch tensor operations, runs the autograd demo, and
/// reports whether CUDA and cuDNN are available.
fn main() {
    // NOTE(review): presumably initialises CUDA when it is present; confirm in the tch docs.
    tch::maybe_init_cuda();

    // Tensor built from an integer slice.
    let ints = Tensor::of_slice(&[3, 1, 4, 1, 5]);
    ints.print();

    // Random 5x4 tensor created with the `kind::FLOAT_CPU` options.
    let random = Tensor::randn(&[5, 4], kind::FLOAT_CPU);
    random.print();
    // Adding a scalar to a borrowed tensor yields a new tensor; `random` is not consumed.
    for offset in [1.5, 2.5] {
        (&random + offset).print();
    }

    // In-place scalar addition on an f32 tensor, then inspect its shape and one element.
    let mut floats = Tensor::of_slice(&[1.1f32, 2.1, 3.1]);
    floats += 42;
    floats.print();
    println!("{:?} {}", floats.size(), floats.double_value(&[1]));

    grad_example();

    println!("Cuda available: {}", tch::Cuda::is_available());
    println!("Cudnn available: {}", tch::Cuda::cudnn_is_available());
}