-
Notifications
You must be signed in to change notification settings - Fork 0
Description
Phase 1: Expanded Core Functionality
Tensor Operations
torch_cat(tensors, n_tensors, dim)
torch_stack(tensors, n_tensors, dim)
torch_split(tensor, split_size, dim)
torch_transpose(tensor, dim0, dim1)
torch_permute(tensor, dims)
torch_flatten(tensor, start_dim, end_dim)
torch_dot(a, b) / torch_mm / torch_bmm
torch_einsum(equation, tensors, n_tensors)
Indexing & Slicing
torch_tensor_index(tensor, indices)
torch_tensor_index_put(tensor, indices, values)
torch_tensor_slice(tensor, dim, start, end, step)
torch_tensor_masked_select(tensor, mask)
torch_tensor_gather(tensor, dim, index)
More NN Layers
torch_nn_conv2d_create(in_ch, out_ch, kernel, stride, padding)
torch_nn_batchnorm2d_create(num_features)
torch_nn_dropout_create(p)
torch_nn_embedding_create(num_embeddings, embedding_dim)
torch_nn_lstm_create(input_size, hidden_size, num_layers)
torch_nn_transformer_create(d_model, nhead, num_layers)
Activation Functions
torch_gelu / torch_silu / torch_mish
torch_leaky_relu(tensor, negative_slope)
torch_elu / torch_selu / torch_prelu
torch_softplus / torch_softsign
Loss Functions
torch_nn_l1_loss
torch_nn_smooth_l1_loss
torch_nn_cosine_embedding_loss
torch_nn_ctc_loss
torch_nn_kl_div_loss
Optimizers
torch_optim_adamw_create
torch_optim_rmsprop_create
torch_optim_sgd_with_nesterov
torch_optim_lr_scheduler_step_create
torch_optim_get_lr(optimizer)
Random & Init
torch_manual_seed(seed)
torch_rand_like / torch_randn_like(tensor)
torch_normal(mean, std, shape)
torch_nn_init_xavier_uniform(tensor)
torch_nn_init_kaiming_normal(tensor)
Comparison & Logic
torch_eq/ne/lt/le/gt/ge(a, b)
torch_logical_and/or/not(a, b)
torch_where(condition, x, y)
torch_isnan/isinf/isfinite(tensor)
Reduction Ops
torch_min/max(tensor, dim, keepdim)
torch_argmin/argmax(tensor, dim)
torch_std/var/median(tensor, dim)
torch_prod(tensor, dim)
torch_cumsum/cumprod(tensor, dim)
CUDA Specific
torch_cuda_synchronize()
torch_cuda_empty_cache()
torch_cuda_memory_allocated(device)
torch_cuda_max_memory_allocated(device)