In [None]:
# In this notebook, you will learn:
#
# 1) How torch.unique() works
# 2) How torch.unbind() works

In [2]:
import torch

## [torch.unique](https://pytorch.org/docs/stable/generated/torch.unique.html)

In [3]:
# torch.unique returns the unique elements of a tensor (sorted in ascending order by default).

Let's see how to apply torch.unique to a 1D tensor.

In [5]:
# A small 1D int32 tensor with duplicates (1 and 2 each appear twice),
# so torch.unique has something to remove.
t1 = torch.tensor([1, 2, 5, 2, 3, 1, 6, 7], dtype=torch.int32)
print(f"shape of t1: {t1.shape}")
print(f"t1: {t1}")

shape of t1: torch.Size([8])
t1: tensor([1, 2, 5, 2, 3, 1, 6, 7], dtype=torch.int32)


In [None]:
# torch.unique treats every slice along `dim` as one element and checks
# uniqueness slice-by-slice. For a 1D tensor with dim=0, each slice is just a
# scalar: the values 1, 2, 5, 2, 3, 1, 6, 7 contain the distinct values
# 1, 2, 3, 5, 6, 7. The result therefore has 6 entries, in ascending order
# because torch.unique sorts by default.
t2 = torch.unique(t1, dim=0)
print(f"shape of t2: {t2.shape}")
print(f"t2: {t2}")

shape of t2: torch.Size([6])
t2: tensor([1, 2, 3, 5, 6, 7], dtype=torch.int32)


In [None]:
# With return_counts=True, torch.unique also returns a second tensor holding,
# for each unique element, how many times it occurs in the input:
#   unique element : 1  2  3  5  6  7
#   count          : 2  2  1  1  1  1
t3_unique_elements, t3_counts = torch.unique(t1, dim=0, return_counts=True)

# Unique values first, then a separator line, then their counts.
print(f"shape of t3: {t3_unique_elements.shape}")
print(f"t3: {t3_unique_elements}")
print("-" * 150)
print(f"shape of t3_counts: {t3_counts.shape}")
print(f"t3_counts: {t3_counts}")

shape of t3: torch.Size([6])
t3: tensor([1, 2, 3, 5, 6, 7], dtype=torch.int32)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape of t3_counts: torch.Size([6])
t3_counts: tensor([2, 2, 1, 1, 1, 1])


Let's see how to apply torch.unique to a 2D tensor.

In [19]:
# A 5x4 int32 tensor with repeated rows: row 2 equals row 0, row 4 equals
# row 1, and row 3 ([7, 8, 8, 9]) appears only once.
t4 = torch.tensor(
    [
        [1, 2, 3, 2],
        [4, 5, 6, 10],
        [1, 2, 3, 2],
        [7, 8, 8, 9],
        [4, 5, 6, 10],
    ],
    dtype=torch.int32,
)
print(f"shape of t4: {t4.shape}")
print(f"t4: {t4}")

shape of t4: torch.Size([5, 4])
t4: tensor([[ 1,  2,  3,  2],
        [ 4,  5,  6, 10],
        [ 1,  2,  3,  2],
        [ 7,  8,  8,  9],
        [ 4,  5,  6, 10]], dtype=torch.int32)


In [None]:
# With dim=0, each row of t4 is one element, and rows are compared as whole
# vectors. The five rows contain three distinct ones:
#   [1, 2, 3, 2]  --> 2 times (rows 0 and 2)
#   [4, 5, 6, 10] --> 2 times (rows 1 and 4)
#   [7, 8, 8, 9]  --> 1 time  (row 3)
# Repeated values *inside* a row (e.g. the two 2s in [1, 2, 3, 2]) play no role.
t5_unique_elements, t5_counts = torch.unique(t4, dim=0, return_counts=True)

print(f"shape of t5_unique_elements: {t5_unique_elements.shape}")
print(f"t5_unique_elements: {t5_unique_elements}")
print("-" * 150)
print(f"shape of t5_counts: {t5_counts.shape}")
print(f"t5_counts: {t5_counts}")

shape of t5_unique_elements: torch.Size([3, 4])
t5_unique_elements: tensor([[ 1,  2,  3,  2],
        [ 4,  5,  6, 10],
        [ 7,  8,  8,  9]], dtype=torch.int32)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape of t5_counts: torch.Size([3])
t5_counts: tensor([2, 2, 1])


In [None]:
# With dim=1, each COLUMN of t4 is one element, and columns are compared as
# whole vectors. Individual values are never de-duplicated on their own.
# The columns of t4 are:
#   col 0: [1,  4, 1, 7,  4]
#   col 1: [2,  5, 2, 8,  5]
#   col 2: [3,  6, 3, 8,  6]
#   col 3: [2, 10, 2, 9, 10]
# All four columns are distinct, so nothing is removed. The output keeps all
# 4 columns (shape stays [5, 4]) and every count is 1. This is also why the
# repeated 8 in [7, 8, 8, 9] and the repeated 2 in [1, 2, 3, 2] survive: each
# copy belongs to a different column.
# The only visible change is the column ORDER, from the default sorting.
# Columns are compared lexicographically (first entry, then the next entry to
# break ties):
#   - col 1 and col 3 both start with 2.
#   - Their second entries (5 < 10) break the tie.
# The resulting order is col 0, col 1, col 3, col 2.
t6_unique_elements, t6_counts = torch.unique(input=t4, dim=1, return_counts=True)

# Inline check of the explanation above: the result is exactly t4's columns
# reordered as [0, 1, 3, 2], each occurring once.
assert torch.equal(t6_unique_elements, t4[:, [0, 1, 3, 2]])
assert t6_counts.tolist() == [1, 1, 1, 1]

print(f"shape of t6_unique_elements: {t6_unique_elements.shape}")
print(f"t6_unique_elements: {t6_unique_elements}")
print("-" * 150)
print(f"shape of t6_counts: {t6_counts.shape}")
print(f"t6_counts: {t6_counts}")

shape of t6_unique_elements: torch.Size([5, 4])
t6_unique_elements: tensor([[ 1,  2,  2,  3],
        [ 4,  5, 10,  6],
        [ 1,  2,  2,  3],
        [ 7,  8,  9,  8],
        [ 4,  5, 10,  6]], dtype=torch.int32)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape of t6_counts: torch.Size([4])
t6_counts: tensor([1, 1, 1, 1])


## [torch.unbind](https://pytorch.org/docs/stable/generated/torch.unbind.html)

In [22]:
# torch.unbind removes the specified dimension and returns a tuple of all slices of the tensor along it.

In [26]:
# The 20 consecutive int32 values 0..19, laid out as a 4x5 matrix,
# so each row holds 5 consecutive numbers.
t7 = torch.arange(20, dtype=torch.int32).reshape(4, 5)
print(f"shape of t7: {t7.shape}")
print(f"t7: {t7}")

shape of t7: torch.Size([4, 5])
t7: tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]], dtype=torch.int32)


In [28]:
# Unbinding along dim=0 drops that dimension and yields a tuple with one
# 1D tensor per row of t7 (4 rows -> 4 tensors of 5 values each).
t8_groups = torch.unbind(t7, dim=0)
print(f"type of t8: {type(t8_groups)}")
print(f"t8: {t8_groups}")

type of t8: <class 'tuple'>
t8: (tensor([0, 1, 2, 3, 4], dtype=torch.int32), tensor([5, 6, 7, 8, 9], dtype=torch.int32), tensor([10, 11, 12, 13, 14], dtype=torch.int32), tensor([15, 16, 17, 18, 19], dtype=torch.int32))
