In [37]:
%matplotlib inline

----------------------------
# Session 0
----------------------------

-------------------------
# Tensors

Tensors are a specialized data structure, very similar to arrays and matrices.
In PyTorch, we use tensors to encode the inputs and outputs of a model, as well as the model’s parameters.

Tensors are similar to [NumPy’s](https://numpy.org/) ndarrays, except that tensors can run on GPUs or other hardware accelerators. In fact, tensors and
NumPy arrays can often share the same underlying memory, eliminating the need to copy data (see the “Bridge with NumPy” section of the PyTorch tensors tutorial). Tensors
are also optimized for automatic differentiation (we'll see more about that later in the [Autograd](autogradqs_tutorial.html)
section). If you’re familiar with ndarrays, you’ll be right at home with the Tensor API. If not, follow along!


In [38]:
import torch
import numpy as np

## Initializing a Tensor

Tensors can be initialized in various ways. Take a look at the following examples:

**Directly from data**

Tensors can be created directly from data. The data type is automatically inferred.



In [39]:
data = [[1, 2],[3, 4]]
x_data = torch.tensor(data)
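
To make the automatic type inference concrete, here is a minimal sketch. It assumes PyTorch's default settings, where the default floating-point type is `torch.float32`; the variable names are only illustrative.

```python
import torch

# Python integers are inferred as 64-bit integers
ints = torch.tensor([[1, 2], [3, 4]])
# A single float in the data makes the whole tensor floating point
floats = torch.tensor([[1.0, 2], [3, 4]])
# Booleans stay booleans
bools = torch.tensor([True, False])
# The inferred type can be overridden explicitly
forced = torch.tensor([[1, 2], [3, 4]], dtype=torch.float64)

print(ints.dtype, floats.dtype, bools.dtype, forced.dtype)
```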

**From a NumPy array**

Tensors can be created from NumPy arrays (and vice versa; see the “Bridge with NumPy” section of the PyTorch tensors tutorial).



In [40]:
np_array = np.array(data)
x_np = torch.from_numpy(np_array)
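
The memory sharing mentioned earlier can be observed directly. The sketch below (with illustrative variable names) contrasts `torch.from_numpy`, which shares memory with the array, with `torch.tensor`, which copies the data.

```python
import numpy as np
import torch

np_array = np.array([[1, 2], [3, 4]])
shared = torch.from_numpy(np_array)  # shares memory with np_array
copied = torch.tensor(np_array)      # owns an independent copy of the data

np_array[0, 0] = 100                 # modify the NumPy array in place

print(shared[0, 0].item())  # 100: the change is visible through the tensor
print(copied[0, 0].item())  # 1: the copy is unaffected
```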

**From another tensor:**

The new tensor retains the properties (shape, datatype) of the argument tensor, unless explicitly overridden.



In [41]:
x_ones = torch.ones_like(x_data) # retains the properties of x_data
print(f"Ones Tensor: \n {x_ones} \n")

x_rand = torch.rand_like(x_data, dtype=torch.float) # overrides the datatype of x_data
print(f"Random Tensor: \n {x_rand} \n")

Ones Tensor: 
 tensor([[1, 1],
        [1, 1]]) 

Random Tensor: 
 tensor([[0.1823, 0.6596],
        [0.4568, 0.9685]]) 



**With random or constant values:**

``shape`` is a tuple of tensor dimensions. In the functions below, it determines the dimensionality of the output tensor.



In [42]:
shape = (2,3,)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)

print(f"Random Tensor: \n {rand_tensor} \n")
print(f"Ones Tensor: \n {ones_tensor} \n")
print(f"Zeros Tensor: \n {zeros_tensor}")

# You can also try zeros_like, rand_like and ones_like
another_rand_tensor = torch.rand_like(zeros_tensor)
another_ones_tensor = torch.ones_like(zeros_tensor)
another_zeros_tensor = torch.zeros_like(zeros_tensor)

print(f"Other Random Tensor: \n {another_rand_tensor} \n")
print(f"Other Ones Tensor: \n {another_ones_tensor} \n")
print(f"Other Zeros Tensor: \n {another_zeros_tensor}")

Random Tensor: 
 tensor([[0.8461, 0.7135, 0.8501],
        [0.3846, 0.4067, 0.9118]]) 

Ones Tensor: 
 tensor([[1., 1., 1.],
        [1., 1., 1.]]) 

Zeros Tensor: 
 tensor([[0., 0., 0.],
        [0., 0., 0.]])
Other Random Tensor: 
 tensor([[0.3835, 0.8667, 0.2404],
        [0.9099, 0.6373, 0.0080]]) 

Other Ones Tensor: 
 tensor([[1., 1., 1.],
        [1., 1., 1.]]) 

Other Zeros Tensor: 
 tensor([[0., 0., 0.],
        [0., 0., 0.]])


--------------




## Attributes of a Tensor

Tensor attributes describe their shape, datatype, and the device on which they are stored.



In [43]:
tensor = torch.rand(3,4)

print(f"Shape of tensor: {tensor.shape}")
print(f"Datatype of tensor: {tensor.dtype}")
print(f"Device tensor is stored on: {tensor.device}")

Shape of tensor: torch.Size([3, 4])
Datatype of tensor: torch.float32
Device tensor is stored on: cpu


--------------




## Operations on Tensors

Over 100 tensor operations, including arithmetic, linear algebra, matrix manipulation (transposing,
indexing, slicing), sampling and more, are
comprehensively described [here](https://pytorch.org/docs/stable/torch.html).

Each of these operations can be run on the GPU (typically at higher speeds than on a
CPU). If you’re using Colab, allocate a GPU by going to Runtime > Change runtime type > GPU.

By default, tensors are created on the CPU. We need to explicitly move tensors to the GPU using
the ``.to`` method (after checking for GPU availability). Keep in mind that copying large tensors
across devices can be expensive in terms of time and memory!



In [44]:
# Move our tensor to the GPU if one is available, otherwise keep it on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tensor = tensor.to(device)  # same as tensor.to('cuda') or tensor.to('cpu')
if device.type == 'cuda':
    print(f"Device tensor is stored on: {tensor.device}")

Device tensor is stored on: cuda:0
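
Since copies across devices can be costly, one option is to allocate a tensor on its target device from the start. The following is a minimal sketch using the `device` keyword accepted by creation functions such as `torch.ones`; it falls back to the CPU when no GPU is present.

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Allocated on the CPU first, then moved to `device` (no copy if `device` is the CPU)
moved = torch.ones(2, 3).to(device)
# Allocated directly on `device`, with no intermediate CPU tensor
direct = torch.ones(2, 3, device=device)

print(moved.device, direct.device)
```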


Try out some of the operations from the list.
If you're familiar with the NumPy API, you'll find the Tensor API a breeze to use.




**Standard numpy-like indexing and slicing:**



In [45]:
tensor = torch.ones(3, 4)
print('First row: ', tensor[0])
print('First column: ', tensor[:, 0])
print('Last column:', tensor[..., -1])
tensor[:,1] = 0
print(tensor)

First row:  tensor([1., 1., 1., 1.])
First column:  tensor([1., 1., 1.])
Last column: tensor([1., 1., 1.])
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.]])


**Joining tensors** You can use ``torch.cat`` to concatenate a sequence of tensors along a given dimension.
See also [torch.stack](https://pytorch.org/docs/stable/generated/torch.stack.html),
another tensor joining op that is subtly different from ``torch.cat``.



In [46]:
t1 = torch.cat([tensor, tensor, tensor], dim=1)
print(t1)

tensor([[1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
        [1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
        [1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.]])
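
The difference between the two joining ops shows up in the output shapes. As a small sketch: `torch.cat` extends an existing dimension, whereas `torch.stack` inserts a new one.

```python
import torch

t = torch.ones(3, 4)
concatenated = torch.cat([t, t, t], dim=1)  # grows the existing dimension 1
stacked = torch.stack([t, t, t], dim=1)     # inserts a new dimension at index 1

print(concatenated.shape)  # torch.Size([3, 12])
print(stacked.shape)       # torch.Size([3, 3, 4])
```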


**Arithmetic operations**



In [47]:
# This computes the matrix multiplication between two tensors. y1, y2, y3 will have the same value
y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)

y3 = torch.empty_like(y1)
torch.matmul(tensor, tensor.T, out=y3)

print(f'Equivalent computation: {torch.all(y1==y2) and torch.all(y2==y3)}')


# This computes the element-wise product. z1, z2, z3 will have the same value
z1 = tensor * tensor
z2 = tensor.mul(tensor)

z3 = torch.empty_like(tensor)
torch.mul(tensor, tensor, out=z3)
print(f'Equivalent computation: {torch.all(z1==z2) and torch.all(z2==z3)}')

Equivalent computation: True
Equivalent computation: True


**In-place vs out-of-place operations**



In [48]:
# methods ending with a _ operate in place
x = torch.zeros(2,3)
print(f'x at the beginning:\n {x}')
x.add(1)  # out-of-place: returns a new tensor (discarded here), x is unchanged
print(f'x after out-of-place op:\n {x}')
x.add_(1)  # in-place: modifies x directly
print(f'x after in-place op:\n {x}')

# when no in-place variant exists, one may rely on the `out` parameter
torch.add(x,1, out=x)
print(f'x after another in-place op:\n {x}')

x at the beginning:
 tensor([[0., 0., 0.],
        [0., 0., 0.]])
x after out-of-place op:
 tensor([[0., 0., 0.],
        [0., 0., 0.]])
x after in-place op:
 tensor([[1., 1., 1.],
        [1., 1., 1.]])
x after another in-place op:
 tensor([[2., 2., 2.],
        [2., 2., 2.]])
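
One way to check whether an operation really worked in place is to compare the memory addresses of the tensors involved. This is only an illustrative sketch based on `Tensor.data_ptr()`, which returns the address of a tensor's first element.

```python
import torch

x = torch.zeros(2, 3)
y = x.add(1)   # out-of-place: the result lives in newly allocated memory
z = x.add_(1)  # in-place: the result is written into x's own memory

print(y.data_ptr() == x.data_ptr())  # False
print(z.data_ptr() == x.data_ptr())  # True
```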


**Single-element tensors** If you have a one-element tensor, for example by aggregating all
values of a tensor into one value, you can convert it to a Python
numerical value using ``item()``:



In [49]:
agg = tensor.sum()
agg_item = agg.item()
print(agg_item, type(agg_item))

9.0 <class 'float'>


-------------------------
# Exercises on Tensor manipulation

## Exercise 1: 
Compute $\Vert t_1-t_2\Vert^2$ and return the result as a `float`. 
You should use `torch.square` and `torch.sum` in your solution.

<u>Hint:</u> Beware that $t_1$ and $t_2$ have different shapes, so you should reshape them to a common shape first.
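
The hint leaves open how to bring both tensors to a common shape. As a minimal sketch of two options: `reshape` returns a tensor with the requested shape and leaves its input untouched, whereas `resize_` changes the shape of the tensor it is called on, which is also visible to the caller.

```python
import torch

t1 = torch.ones((3, 4))
t2 = torch.arange(0, 12)

# t1 - t2 would raise an error here: shapes (3, 4) and (12,) cannot be broadcast together
flat = t1.reshape(t2.shape)  # out-of-place: t1 keeps its (3, 4) shape
print(flat.shape, t1.shape)  # torch.Size([12]) torch.Size([3, 4])

t1.resize_(t2.shape)         # in-place: t1 itself now has shape (12,)
print(t1.shape)              # torch.Size([12])
```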

In [50]:
debug = False
def exercise1(t1,t2):
  # reshape leaves the caller's t1 untouched (resize_ would change its shape in place)
  diff = t1.reshape(t2.shape) - t2
  return torch.sum(torch.square(diff)).item()

def test_exercise1():
  t1 = torch.ones((3,4))
  t2 = torch.arange(0,12)
  result = exercise1(t1,t2)
  expected = 386
  if isinstance(result, float) and result == expected:
    print('Exercise 1: OK')
  else:
    print('Exercise 1: NOK')
    if not isinstance(result, float):
      print('Hint: don\'t forget to return a float and not a tensor (with item())')
    if debug:
      print(f' result : {result}')
      print(f' expected : {expected}')

test_exercise1()

Exercise 1: OK


## Exercise 2: Monte Carlo estimation of $\pi/4$
To estimate $\pi/4$, we can generate $n$ random samples in the unit square and compute the fraction of samples whose norm is less than 1, since the quarter disc of radius 1 covers an area of $\pi/4$ inside the unit square. To do so in PyTorch, generate a random tensor of shape $n\times 2$, compute the norm of each row (one row per sample), and work out the rest of the solution from there.
![Illustration](https://upload.wikimedia.org/wikipedia/commons/thumb/8/84/Pi_30K.gif/440px-Pi_30K.gif)


In [51]:
debug = False

def exercise2(n_samples):
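  resultat = torch.rand(n_samples,2).double() # n_samples points in the unit square
  resultat = torch.norm(resultat,dim=1) # dim=1: norm of each row (one row per sample)
  resultat = (resultat <=1).sum().double() # number of points inside the quarter disc
  return resultat/n_samples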

def test_exercise2():
  result = exercise2(int(1e7))
  expected = torch.tensor([np.pi/4]).double()
  print(expected)
  if  torch.isclose(result, expected, atol=1e-3):
    print('Exercise 2: OK')
  else:
    print('Exercise 2: NOK')
    if debug:
      print(f' result : {result.item()}')
      print(f' expected : {expected.item()}')

test_exercise2()

tensor([0.7854], dtype=torch.float64)
Exercise 2: OK


## Exercise 3: In-place computation of the Fibonacci number
Denoting $A=\begin{pmatrix} 1& 1\\ 1 &0\end{pmatrix}$, one may show that $A^n=\begin{pmatrix} F_{n+1}& F_n \\ F_n & F_{n-1}\end{pmatrix}$. Using this fact together with $\frac{F_{n+1}}{F_n}\to \frac{1+\sqrt{5}}{2}$, compute $\phi:=\frac{1+\sqrt{5}}{2}$ using only in-place operations. In practice, initialize $A$ to the correct matrix and iteratively compute $A^{(2^n)}$ in place (using the `out` parameter of the `torch.mm` routine).
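
To see why repeated squaring reaches $A^{(2^n)}$, the identity can be checked numerically for $n=3$. The sketch below is deliberately out-of-place (it is an illustration, not the in-place solution the exercise asks for) and uses `float64` so that the small integers involved stay exact.

```python
import torch

A = torch.tensor([[1., 1.], [1., 0.]], dtype=torch.float64)
P = A.clone()
for k in range(3):
    P = P @ P  # after k+1 squarings, P = A^(2^(k+1))

# A^8 = [[F_9, F_8], [F_8, F_7]] = [[34, 21], [21, 13]]
print(P)
# F_9 / F_8 = 34/21, already close to the golden ratio
print((P[0, 0] / P[1, 0]).item())
```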

In [52]:
debug = False
import math

def exercise3(A, n):
  '''Compute A^(2^n) for the Fibonacci matrix. The computation must be in-place'''
  A.copy_(torch.tensor([[1,1],[1,0]]))  # in-place initialization of A
  for i in range(n):
    torch.mm(A,A,out=A)  # each squaring doubles the exponent


def test_exercise3():
  A = torch.empty((2,2))
  exercise3(A,3)
  result = A[0,0]/A[1,0]
  expected = (1+torch.sqrt(torch.tensor(5.)))/2
  if  torch.isclose(result, expected, 1e-3):
    print('Exercise 3: OK')
  else:
    print('Exercise 3: NOK')
    if debug:
      print(f' result : {result}')
      print(f' expected : {expected}')

test_exercise3()

Exercise 3: OK


## Exercise 4: Hamming error-correcting code
Here you will implement the Hamming error-correcting code scheme. To do so, you will need to implement the following functions:
* `nbits(x)`, which computes the maximum number of bits necessary to encode the entries of the int tensor `x`
* `int2bits(x)`, which converts an int tensor into its binary representation; the output contains one extra dimension and is composed of booleans
* `bits2int(bits)`, which is the reverse of the previous routine
* `nparity_bits(n)`, which computes the number of parity bits $r$ given the number of data bits $n$ (using exhaustive search and the constraint $2^r \geq n + r + 1$)
* `init_parity_bits(bits)`, which inserts parity bits (initialized to 0) at the right positions (at all power-of-two locations in 1-indexing)
* `hamming_encode(bits)`, which calls the previous function and then computes the actual values of the parity bits. The $k^\text{th}$ parity bit is the parity of all bits whose 1-indexed position has a 1 at the $k^\text{th}$ position of its binary decomposition (see the following <a href=#tab>table</a> or the more complete description at https://en.wikipedia.org/wiki/Hamming_code#General_algorithm)
* `locate_errors_in_0indexing(bits,r)`, which returns the 0-indexed locations of the errors (assuming that at most one bit was flipped per entry)
* `hamming_decode(bits,r)`, which calls the previous function and then corrects each error (at most one per entry) and returns the corrected bits (except for the parity bits, which are dropped)
<a id="tab"/>
    
<dl><dd><table class="wikitable" style="text-align:center;">

<tbody><tr>
<th colspan="2"><span style="color:#888">Bit position</span>
</th>
<th><span style="color:#888">1</span></th>
<th><span style="color:#888">2</span></th>
<th><span style="color:#888">3</span></th>
<th><span style="color:#888">4</span></th>
<th><span style="color:#888">5</span></th>
<th><span style="color:#888">6</span></th>
<th><span style="color:#888">7</span></th>
<th><span style="color:#888">8</span></th>
<th><span style="color:#888">9</span></th>
<th><span style="color:#888">10</span></th>
<th><span style="color:#888">11</span></th>
<th><span style="color:#888">12</span></th>
<th><span style="color:#888">13</span></th>
<th><span style="color:#888">14</span></th>
<th><span style="color:#888">15</span></th>
<th><span style="color:#888">16</span></th>
<th><span style="color:#888">17</span></th>
<th><span style="color:#888">18</span></th>
<th><span style="color:#888">19</span></th>
<th><span style="color:#888">20</span>
</th>
<td rowspan="7">…
</td></tr>
<tr>
<th colspan="2">Encoded data bits
</th>
<th style="background-color: #90FF90;">p1
</th>
<th style="background-color: #90FF90;">p2</th>
<th>d1
</th>
<th style="background-color: #90FF90;">p4</th>
<th>d2</th>
<th>d3</th>
<th>d4
</th>
<th style="background-color: #90FF90;">p8</th>
<th>d5</th>
<th>d6</th>
<th>d7</th>
<th>d8</th>
<th>d9</th>
<th>d10</th>
<th>d11
</th>
<th style="background-color: #90FF90;">p16</th>
<th>d12</th>
<th>d13</th>
<th>d14</th>
<th>d15
</th></tr>
<tr>
<th rowspan="5">Parity<br>bit<br>coverage
</th>
<th style="background-color: #90FF90;">p1
</th>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td>
</td></tr>
<tr>
<th style="background-color: #90FF90;">p2
</th>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td>
</td></tr>
<tr>
<th style="background-color: #90FF90;">p4
</th>
<td></td>
<td></td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td></td>
<td></td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td></td>
<td></td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td></tr>
<tr>
<th style="background-color: #90FF90;">p8
</th>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td></td>
<td></td>
<td></td>
<td></td>
<td>
</td></tr>
<tr>
<th style="background-color: #90FF90;">p16
</th>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td></td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;</td>
<td data-sort-value="Yes" style="background: #DFD; vertical-align: middle; text-align: center;" class="table-yes2">&#10004;
</td></tr></tbody></table></dd></dl>
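
The coverage pattern in the table has a convenient consequence: for a valid codeword, the XOR of the 1-indexed positions of all set bits is 0, and after a single bit flip it equals the 1-indexed position of the flipped bit. The sketch below illustrates this in plain Python on one codeword (the encoding of the data bits 1,0,1,1,0,0,1,1). The helper name `syndrome` is hypothetical, and this is not the tensor implementation the exercise asks for.

```python
from functools import reduce
from operator import xor

def syndrome(codeword):
    """XOR of the 1-indexed positions of all set bits (0 for a valid codeword)."""
    return reduce(xor, (pos for pos, bit in enumerate(codeword, start=1) if bit), 0)

# Data bits 1,0,1,1,0,0,1,1 with parity bits at 1-indexed positions 1, 2, 4 and 8
codeword = [1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1]
print(syndrome(codeword))   # 0: no error detected

corrupted = codeword.copy()
corrupted[5] ^= 1           # flip the bit at 0-indexed location 5 (1-indexed position 6)
print(syndrome(corrupted))  # 6: the 1-indexed position of the flipped bit
```

Bit $k$ of this XOR is exactly the parity check associated with the $k^\text{th}$ parity bit, which is why subtracting 1 from it yields the 0-indexed error location.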



In [54]:
def nbits(x):
    # Number of bits needed for the largest entry (at least 1, so that 0 still gets one bit)
    max_val = int(torch.max(x).item())
    return max(1, max_val.bit_length())
    

def test_nbits():
    x = torch.arange(0, 128)
    n1 = nbits(x)
    n2 = nbits(x+1)
    n3 = nbits(x+1)
    if (n1,n2,n3) == (7,8,8):
        print('test_nbits [OK]')
    else:
        print('test_nbits [NOK]')
    
test_nbits()

def int2bits(x):
    n = nbits(x)
    bits = torch.empty(list(x.shape)+[n])
    for i in range(n):
        bits[...,i] = ((x>> i) &  1)
    return bits.to(bool)


def test_int2bits():
    x = torch.arange(0, 8)
    bits = int2bits(x)
    if (bits.dtype == torch.bool) and torch.all(bits == torch.tensor([[0, 0, 0],[1, 0, 0],[0, 1, 0],[1, 1, 0],[0, 0, 1],[1, 0, 1],[0, 1, 1],[1, 1, 1]])):
        print('test_int2bits [OK]')
    else:
        print('test_int2bits [NOK]')
    
test_int2bits()


def bits2int(bits):
    n = bits.shape[-1]
    # Integer accumulator: a float32 one would lose precision for large values
    x = torch.zeros(bits.shape[:-1], dtype=torch.int64)
    for i in range(n):
        x += bits[...,i].to(torch.int64) * (1 << i)
    return x

def test_bits2int():
    bits = torch.tensor([[0, 0, 0],[1, 0, 0],[0, 1, 0],[1, 1, 0],[0, 0, 1],[1, 0, 1],[0, 1, 1],[1, 1, 1]]).to(bool)
    x = bits2int(bits)
    if (x.dtype == torch.int64) and torch.all(x == torch.arange(0, 8)):
        print('test_bits2int [OK]')
    else:
        print('test_bits2int [NOK]')
    
test_bits2int()


def nparity_bits(n):
    r = 0
    while (2 ** r) < (n + r + 1):
        r += 1
    return r
        
def test_nparity_bits():
    r10 = nparity_bits(10)
    r11 = nparity_bits(11)
    r12 = nparity_bits(12)
    if (r10,r11,r12) == (4,4,5):
        print('test_nparity_bits [OK]')
    else:
        print('test_nparity_bits [NOK]')

test_nparity_bits()


def init_parity_bits(bits):
    n = bits.shape[-1]
    r = nparity_bits(n)
    nb_bits = n+r

    newbits = torch.zeros(bits.shape[:-1] + (nb_bits,))
    data_idx = 0
    for i in range(nb_bits):
        if (i + 1) & i:  # 1-indexed position i+1 is not a power of two, so it holds a data bit
            newbits[..., i] = bits[..., data_idx]
            data_idx += 1

    return newbits.to(bool)

def test_init_parity_bits():
    bits = torch.rand(10, 8)>0.5
    all_bits = init_parity_bits(bits)
    ntotal = all_bits.shape[-1]
    test = (ntotal == 12)
    z = torch.zeros_like(bits[...,0])
    for k in range(4):
        test = test and torch.allclose(z, all_bits[...,(1<<k)-1])

    if test:
        print('test_init_parity_bits [OK]')
    else:
        print('test_init_parity_bits [NOK]')

        
test_init_parity_bits()


def hamming_encode(bits):
    n = bits.shape[-1]
    r = nparity_bits(n)
    bits = init_parity_bits(bits)   
    n = bits.shape[-1]
    
    for k in range(r):  
        parity_pos = (1 << k) - 1  
        parity = torch.zeros(bits.shape[:-1], dtype=torch.bool)
        for i in range(1, n + 1):  
            if i & (1 << k):  
                parity = parity ^ bits[..., i-1] 
        
        bits[..., parity_pos] = parity
    
    return bits

def test_hamming_encode():
    bits = torch.rand(10, 8)>0.5
    bits[0,:] = torch.Tensor([1, 0, 1, 1, 0, 0, 1, 1]).to(bool)
    expected = torch.Tensor([1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1]).to(int)
    all_bits = hamming_encode(bits)
    ntotal = all_bits.shape[-1]
    test = (ntotal == 12)
    test = test and  torch.all(all_bits[0,:] == expected)
    if test:
        print('test_hamming_encode [OK]')
    else:
        print('test_hamming_encode [NOK]')

        
test_hamming_encode()

import torch.nn.functional as F # useful for F.one_hot

def locate_errors_in_0indexing(bits, r):
    n = bits.shape[-1]
    error_locations_1indexing = torch.zeros(bits.shape[:-1],dtype=int)
    
    for k in range(r):
        parity = torch.zeros(bits.shape[:-1], dtype=torch.bool)
        for i in range(1, n + 1): 
            if i & (1 << k):  
                parity = parity ^ bits[..., i-1]  
        
        error_locations_1indexing += parity.to(torch.int64) * (1 << k)
    
    error_locations_0indexing = error_locations_1indexing-1
    return error_locations_0indexing
    

def test_locate_errors_in_0indexing():
    bits = torch.Tensor([1, 0, 1, 1, 0, 0, 1, 1]).to(bool).repeat( [3, 1])
    r = nparity_bits(8) 
    all_bits = hamming_encode(bits)
    erroneous = all_bits.clone()
    erroneous[0,5] = not erroneous[0,5]
    erroneous[1,4] = not erroneous[1,4]
    erroneous[2,3] = not erroneous[2,3]
    error_locations_0indexing = locate_errors_in_0indexing(erroneous, r)

    if error_locations_0indexing.tolist() == [5,4,3]:
        print('test_locate_errors_in_0indexing [OK]')
    else:
        print('test_locate_errors_in_0indexing [NOK]')
    
test_locate_errors_in_0indexing()
    
def hamming_decode(bits, r):
    n = bits.shape[-1]
    error_locations_0indexing = locate_errors_in_0indexing(bits, r)

    # Flip the erroneous bit of each entry. A location of -1 means "no error":
    # it matches no position, so such entries are left untouched.
    # This works for any number of leading (batch) dimensions, including none.
    positions = torch.arange(n)
    flip = positions == error_locations_0indexing.unsqueeze(-1)
    corrected_bits = bits ^ flip

    # Drop the parity bits: keep the positions whose 1-index is not a power of two
    data_positions = [i for i in range(n) if (i + 1) & i]
    return corrected_bits[..., data_positions]

def test_hamming_decode():
    bits = torch.Tensor([1, 0, 1, 1, 0, 0, 1, 1]).to(bool).repeat( [3, 1])
    r = nparity_bits(8) 
    all_bits = hamming_encode(bits)
    erroneous = all_bits.clone()
    erroneous[0,5] = not erroneous[0,5]
    erroneous[1,4] = not erroneous[1,4]
    erroneous[2,3] = not erroneous[2,3]
    decoded = hamming_decode(erroneous, r)

    if torch.all(decoded == bits):
        print('test_hamming_decode [OK]')
    else:
        print('test_hamming_decode [NOK]')

        
test_hamming_decode()
    

                    
    

test_nbits [OK]
test_int2bits [OK]
test_bits2int [OK]
test_nparity_bits [OK]
test_init_parity_bits [OK]
test_hamming_encode [OK]
test_locate_errors_in_0indexing [OK]
test_hamming_decode [OK]
