
Commit

Fix VarType
co63oc committed Feb 5, 2024
1 parent ca68a91 commit e4d6d21
Showing 2 changed files with 12 additions and 7 deletions.
8 changes: 6 additions & 2 deletions python/paddle/tensor/creation.py
@@ -53,14 +53,18 @@ def _complex_to_real_dtype(dtype):
         return core.VarDesc.VarType.FP32
     elif dtype == core.VarDesc.VarType.COMPLEX128:
         return core.VarDesc.VarType.FP64
+    elif dtype == paddle.pir.core.DataType.COMPLEX64:
+        return paddle.pir.core.DataType.FP32
+    elif dtype == paddle.pir.core.DataType.COMPLEX128:
+        return paddle.pir.core.DataType.FP64
     else:
         return dtype
 
 
 def _real_to_complex_dtype(dtype):
-    if dtype == core.VarDesc.VarType.FP32:
+    if dtype == paddle.float32:
         return core.VarDesc.VarType.COMPLEX64
-    elif dtype == core.VarDesc.VarType.FP64:
+    elif dtype == paddle.float64:
         return core.VarDesc.VarType.COMPLEX128
     else:
         return dtype
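The creation.py change teaches these dtype helpers about the new PIR enum: _complex_to_real_dtype gains explicit paddle.pir.core.DataType branches, while _real_to_complex_dtype now compares against paddle.float32/paddle.float64, which resolve to the matching value in whichever IR is active. A minimal sketch of the resulting behavior (assuming a Paddle build where PIR is available; the helpers are private, so importing them this way is an implementation detail, not public API):

    import paddle
    from paddle.tensor.creation import _complex_to_real_dtype, _real_to_complex_dtype

    # New PIR branch: DataType values map within the PIR enum.
    print(_complex_to_real_dtype(paddle.pir.core.DataType.COMPLEX64))  # DataType.FP32
    # paddle.float64 compares equal to the active IR's float64 value, so the
    # check holds whether dtype came from VarDesc.VarType or pir DataType.
    print(_real_to_complex_dtype(paddle.float64))  # COMPLEX128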
11 changes: 6 additions & 5 deletions python/paddle/tensor/to_string.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 
+import paddle
 from paddle.base.data_feeder import check_type, convert_dtype
 
 from ..framework import core
@@ -238,7 +239,7 @@ def to_string(var, prefix='Tensor'):
     indent = len(prefix) + 1
 
     dtype = convert_dtype(var.dtype)
-    if var.dtype == core.VarDesc.VarType.BF16:
+    if var.dtype == paddle.bfloat16:
         dtype = 'bfloat16'
 
     _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"
@@ -247,7 +248,7 @@ def to_string(var, prefix='Tensor'):
     if not tensor._is_initialized():
         return "Tensor(Not initialized)"
 
-    if var.dtype == core.VarDesc.VarType.BF16:
+    if var.dtype == paddle.bfloat16:
         var = var.astype('float32')
     np_var = var.numpy(False)
 
@@ -280,7 +281,7 @@ def to_string(var, prefix='Tensor'):
 
 
 def _format_dense_tensor(tensor, indent):
-    if tensor.dtype == core.VarDesc.VarType.BF16:
+    if tensor.dtype == paddle.bfloat16:
         tensor = tensor.astype('float32')
 
     # TODO(zhouwei): will remove 0-D Tensor.numpy() hack
@@ -360,7 +361,7 @@ def dist_tensor_to_string(tensor, prefix='Tensor'):
     # is ready.
     indent = len(prefix) + 1
     dtype = convert_dtype(tensor.dtype)
-    if tensor.dtype == core.VarDesc.VarType.BF16:
+    if tensor.dtype == paddle.bfloat16:
         dtype = 'bfloat16'
 
     if not tensor._is_dense_tensor_hold_allocation():
@@ -395,7 +396,7 @@ def tensor_to_string(tensor, prefix='Tensor'):
     indent = len(prefix) + 1
 
     dtype = convert_dtype(tensor.dtype)
-    if tensor.dtype == core.VarDesc.VarType.BF16:
+    if tensor.dtype == paddle.bfloat16:
         dtype = 'bfloat16'
 
     _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"
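All of the to_string.py edits follow one pattern: the bfloat16 check now compares against paddle.bfloat16 rather than the legacy core.VarDesc.VarType.BF16, so it also matches under PIR; the newly added import paddle supplies that alias. A short usage sketch (assuming an eager-mode build where bfloat16 casts are supported on the current device):

    import paddle

    x = paddle.to_tensor([1.25, 2.5]).astype('bfloat16')
    # print() goes through tensor_to_string: the paddle.bfloat16 comparison
    # matches, the data is cast to float32 for formatting, and the reported
    # dtype stays 'bfloat16'.
    print(x)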
