In [1]:
import torch

# Draw a 2x4 matrix of standard-normal samples, then show it alongside its shape
tensor = torch.randn(2, 4)
print("Original tensor:", tensor, f"Shape: {tensor.shape}", sep="\n")

# A negative slice start counts from the right: keep columns -3, -2, -1 of every row
last_3_columns = tensor[:, -3:]
print("\nLast 3 columns:", last_3_columns, f"Shape: {last_3_columns.shape}", sep="\n")

Original tensor:
tensor([[-0.8148, -1.6768,  1.5075, -1.6973],
        [ 0.9960,  0.3291, -0.6550,  0.7242]])
Shape: torch.Size([2, 4])

Last 3 columns:
tensor([[-1.6768,  1.5075, -1.6973],
        [ 0.3291, -0.6550,  0.7242]])
Shape: torch.Size([2, 3])


In [3]:
# Second-to-last column of every row; selecting a single index drops that dim (1-D result)
tensor.select(dim=1, index=-2)

tensor([ 1.5075, -0.6550])

In [4]:
# A 1-D integer tensor holding the three VALUES 1, 4 and 50257 — not a tensor of that shape.
# NOTE: the name `logits` is rebound to a random 3-D tensor in a later cell.
logits = torch.tensor((1, 4, 50_257))

In [5]:
# Bare last expression: Jupyter renders it as the cell's output (the tensor's repr)
logits

tensor([    1,     4, 50257])


In [6]:
# Position of the largest entry along the last axis; for the values [1, 4, 50257] that is index 2
logits.argmax(dim=-1)

tensor(2)

In [7]:
# Create a tensor with shape [1, 4, 50257]
# NOTE: this rebinds `logits`, replacing the 1-D tensor of values from an earlier cell.
torch.manual_seed(0)  # seed so the random values (and hence the argmax indices) are reproducible
logits = torch.randn(1, 4, 50257)
print(f"Original shape: {logits.shape}")

# torch.argmax with different dim parameters: the reduced dim disappears from the result shape
result_dim_neg1 = torch.argmax(logits, dim=-1)  # Same as dim=2
result_dim_0 = torch.argmax(logits, dim=0)
result_dim_1 = torch.argmax(logits, dim=1)
result_dim_2 = torch.argmax(logits, dim=2)

for dim, result in [
    (-1, result_dim_neg1),
    (0, result_dim_0),
    (1, result_dim_1),
    (2, result_dim_2),
]:
    print(f"torch.argmax(logits, dim={dim}) shape: {result.shape}")

# dim=-1 counts from the end, so on this 3-D tensor it reduces the same axis as dim=2
assert torch.equal(result_dim_neg1, result_dim_2)

# Without specifying dim (flattens and finds global argmax)
result_no_dim = torch.argmax(logits)
print(f"torch.argmax(logits) shape: {result_no_dim.shape}")

Original shape: torch.Size([1, 4, 50257])
torch.argmax(logits, dim=-1) shape: torch.Size([1, 4])
torch.argmax(logits, dim=0) shape: torch.Size([4, 50257])
torch.argmax(logits, dim=1) shape: torch.Size([1, 50257])
torch.argmax(logits, dim=2) shape: torch.Size([1, 4])
torch.argmax(logits) shape: torch.Size([])


In [10]:
# Create a simple tensor to illustrate
x = torch.tensor([10.5, 3.2, 15.8, 7.1])
print(f"Original tensor: {x}")

# torch.argmax returns the INDEX of the max value (as a 0-d tensor)
max_index = torch.argmax(x)
print(f"torch.argmax(x): {max_index}")  # Returns 2 (the index)
print(f"Type: {type(max_index.item())}")  # .item() turns the 0-d index tensor into a Python int

# To get the actual MAX VALUE, use torch.max
max_value = torch.max(x)
# Prints 15.800000190734863 rather than 15.8: the stored value is the nearest float32 to 15.8,
# and the f-string shows it as a full-precision Python float.
print(f"torch.max(x): {max_value}")

# You can also use the index to get the value
actual_max_value = x[max_index]
print(f"x[max_index]: {actual_max_value}")  # Same value as torch.max(x)
assert actual_max_value == max_value

# For 2D example
y = torch.tensor([[1.0, 5.0, 3.0],
                  [2.0, 1.0, 4.0]])
print(f"\n2D tensor:\n{y}")
print(f"argmax along dim=-1: {torch.argmax(y, dim=-1)}")  # [1, 2] - indices
print(f"max along dim=-1: {torch.max(y, dim=-1)}")       # values and indices
# With dim, torch.max also returns the argmax positions as `.indices`
assert torch.equal(torch.max(y, dim=-1).indices, torch.argmax(y, dim=-1))

Original tensor: tensor([10.5000,  3.2000, 15.8000,  7.1000])
torch.argmax(x): 2
Type: <class 'int'>
torch.max(x): 15.800000190734863
x[max_index]: 15.800000190734863

2D tensor:
tensor([[1., 5., 3.],
        [2., 1., 4.]])
argmax along dim=-1: tensor([1, 2])
max along dim=-1: torch.return_types.max(
values=tensor([5., 4.]),
indices=tensor([1, 2]))


In [12]:
# Row-wise maxima: a (values, indices) named tuple — the largest entry per row and where it sits
y.max(dim=-1)

torch.return_types.max(
values=tensor([5., 4.]),
indices=tensor([1, 2]))