# 16. Operator Overloading

> There are some things that I kind of feel torn about, like operator overloading. I left out operator overloading as a fairly personal choice because I had seen too many people abuse it in C++.
>> James Gosling, creator of Java

## Unary Operators

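_The Python Language Reference_ lists three unary operators:

- `-`, implemented by `__neg__`: arithmetic unary negation
- `+`, implemented by `__pos__`: arithmetic unary plus
- `~`, implemented by `__invert__`: bitwise inversion of an integer, defined as `~x == -(x + 1)`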
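Here is a minimal sketch of the first two methods. It uses a hypothetical `Vec2` class that is not part of the notebook. Each method builds and returns a new object instead of modifying `self`, which is the usual convention for operator methods. `__invert__` is shown on a built-in `int`, where `~5` is `-(5 + 1)`.

```python
class Vec2:
    """Hypothetical 2D vector used only to illustrate unary operators."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f'Vec2({self.x!r}, {self.y!r})'

    def __eq__(self, other):
        if isinstance(other, Vec2):
            return (self.x, self.y) == (other.x, other.y)
        return NotImplemented

    def __neg__(self):
        # -v: return a NEW vector with every component negated
        return Vec2(-self.x, -self.y)

    def __pos__(self):
        # +v: return a new vector equal to self (a copy, not self itself)
        return Vec2(self.x, self.y)


v = Vec2(3, -4)
print(-v)        # Vec2(-3, 4)
print(+v)        # Vec2(3, -4)
print(+v is v)   # False: __pos__ built a new object
print(~5)        # -6: int.__invert__ computes -(5 + 1)
```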



## Overloading `+` for Vector Addition

Adding two Euclidean vectors results in a new vector in which the components are the pairwise additions of the components of the operands:
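For two vectors of the same dimension $n$, this reads:

$$
(a_1, a_2, \dots, a_n) + (b_1, b_2, \dots, b_n) = (a_1 + b_1,\ a_2 + b_2,\ \dots,\ a_n + b_n)
$$

For example, $(3, 4, 5) + (10, 20, 30) = (13, 24, 35)$. The `__add__` method below goes further than this definition: it accepts operands of different lengths by padding the shorter one with `0.0` through `itertools.zip_longest`.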

In [58]:
from array import array
import reprlib
import math

In [59]:
class Vector:
    typecode = 'd'

    def __init__(self, components):
        self._components = array(self.typecode, components)

    def __len__(self):
        return len(self._components)

    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components) 
        components = components[components.find('['):-1] 
        return f'Vector({components})'
    
    def __str__(self):
        return str(tuple(self))

    def __bytes__(self):
        return (bytes([ord(self.typecode)]) +
                bytes(self._components))
    
    def __eq__(self, other):
        return tuple(self) == tuple(other)
    
    def __abs__(self):
        return math.hypot(*self)
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode) 
        return cls(memv)

In [60]:
import itertools

def __add__(self, other):
    pairs = itertools.zip_longest(self, other, fillvalue=0.0)
    return Vector(a + b for a, b in pairs)

Vector.__add__ = __add__

In [61]:
v1 = Vector( [3, 4, 5] )
v1 + (10, 20, 30)

Vector([13.0, 24.0, 35.0])

In [62]:
# Vector2d is defined in vector2d_v3_slots.py, an example module from an earlier chapter
from vector2d_v3_slots import Vector2d

v2d = Vector2d(1, 2)
v1 + v2d

Vector([4.0, 6.0, 5.0])

In [63]:
try:
    (10, 20, 30) + v1
except TypeError as e:
    print(f"{e=}")

e=TypeError('can only concatenate tuple (not "Vector") to tuple')


In [64]:
def __radd__(self, other):
    return self + other

Vector.__radd__ = __radd__

In [65]:
try:
    print(f"{(10, 20, 30) + v1=}")
except TypeError as e:
    print(f"{e=}")

(10, 20, 30) + v1=Vector([13.0, 24.0, 35.0])
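The interplay between `__add__` and `__radd__` can be modeled in plain Python. The `dispatch_add` function and the `Meters` class below are hypothetical illustrations. They are not part of the notebook or the standard library. The model deliberately omits one rule: when the right operand's type is a subclass of the left operand's type, Python tries the reverse method first.

```python
def dispatch_add(a, b):
    """Hypothetical helper: a simplified model of how Python evaluates a + b.

    Simplification: real Python tries b.__radd__ FIRST when type(b) is a
    subclass of type(a); that rule is omitted here.
    """
    # 1. Try the forward method on the left operand's class.
    forward = getattr(type(a), '__add__', None)
    if forward is not None:
        result = forward(a, b)
        if result is not NotImplemented:
            return result
    # 2. Try the reverse method on the right operand's class,
    #    but only when the operand types differ.
    if type(a) is not type(b):
        reverse = getattr(type(b), '__radd__', None)
        if reverse is not None:
            result = reverse(b, a)
            if result is not NotImplemented:
                return result
    # 3. Neither side handled the operation.
    raise TypeError(f'unsupported operand type(s) for +: '
                    f'{type(a).__name__!r} and {type(b).__name__!r}')


class Meters:
    """Hypothetical demo class that adds to other Meters or to numbers."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        if isinstance(other, Meters):
            return Meters(self.value + other.value)
        if isinstance(other, (int, float)):
            return Meters(self.value + other)
        return NotImplemented          # let Python try the other operand

    def __radd__(self, other):
        return self + other            # addition is commutative here


print(dispatch_add(Meters(2), 3).value)   # 5, via Meters.__add__
print(dispatch_add(3, Meters(2)).value)   # 5, int.__add__ gives up, Meters.__radd__ handles it
try:
    dispatch_add(Meters(2), 'x')
except TypeError as e:
    print(e)   # unsupported operand type(s) for +: 'Meters' and 'str'
```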


In [66]:
try:
    v1 + 1
except TypeError as e:
    print(f"{e=}")

e=TypeError("'int' object is not iterable")


In [67]:
try:
    v1 + 'ABC'
except TypeError as e:
    print(f"{e=}")

e=TypeError("unsupported operand type(s) for +: 'float' and 'str'")


In [68]:
def __add__(self, other):
    try:
        pairs = itertools.zip_longest(self, other, fillvalue=0.0)
        return Vector(a + b for a, b in pairs)
    except TypeError:
        return NotImplemented
    
def __radd__(self, other):
    return self + other

Vector.__add__ = __add__
Vector.__radd__ = __radd__

In [69]:
try:
    v1 + 1
except TypeError as e:
    print(f"{e=}")

e=TypeError("unsupported operand type(s) for +: 'Vector' and 'int'")


In [70]:
try:
    v1 + 'ABC'
except TypeError as e:
    print(f"{e=}")

e=TypeError("unsupported operand type(s) for +: 'Vector' and 'str'")
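Returning `NotImplemented` does more than produce a clearer error message. It gives the right-hand operand a chance to handle the operation. The sketch below uses three hypothetical classes. `Strict` raises `TypeError` directly from `__add__`, so the other operand's `__radd__` never runs. `Polite` returns `NotImplemented`, so Python goes on to try `Tag.__radd__`.

```python
class Tag:
    """Hypothetical type that knows how to be added on the RIGHT of anything."""

    def __radd__(self, other):
        return f'tagged {type(other).__name__}'


class Strict:
    def __add__(self, other):
        if not isinstance(other, Strict):
            # Raising here ends the dispatch: Tag.__radd__ is never tried
            raise TypeError('Strict only adds to Strict')
        return Strict()


class Polite:
    def __add__(self, other):
        if not isinstance(other, Polite):
            # Signal "I can't handle this", so Python tries other.__radd__
            return NotImplemented
        return Polite()


print(Polite() + Tag())      # tagged Polite
try:
    Strict() + Tag()
except TypeError as e:
    print(e)                 # Strict only adds to Strict
```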


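## Overloading `*` for Scalar Multiplication

Multiplying a vector by a scalar (a plain number) produces a new vector in which every component is multiplied by that number. Implementing both `__mul__` and `__rmul__` lets the scalar appear on either side of `*`.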

In [71]:
def __mul__(self, scalar):
    try:
        factor = float(scalar)
    except (TypeError, ValueError):  # e.g. float([1, 2]) or float('abc')
        return NotImplemented
    return Vector(n * factor for n in self)

def __rmul__(self, scalar):
    return self * scalar

Vector.__mul__ = __mul__
Vector.__rmul__ = __rmul__

In [72]:
v1 = Vector([1.0, 2.0, 3.0])
14 * v1

Vector([14.0, 28.0, 42.0])

In [73]:
v1 * True

Vector([1.0, 2.0, 3.0])

In [74]:
from fractions import Fraction
v1 * Fraction(1, 3)

Vector([0.3333333333333333, 0.6666666666666666, 1.0])
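The `float(scalar)` test accepts anything convertible to `float`, including numeric strings: `v1 * '3'` is treated as `v1 * 3.0`. A stricter alternative, and a different technique from the cell above, is to check the operand against the `numbers.Real` abstract base class. The sketch below uses a hypothetical stripped-down `Vec` class so that it runs on its own.

```python
import numbers
from fractions import Fraction


class Vec:
    """Hypothetical minimal vector with only what scalar * needs."""

    def __init__(self, components):
        self._components = tuple(float(c) for c in components)

    def __iter__(self):
        return iter(self._components)

    def __repr__(self):
        return f'Vec({list(self._components)!r})'

    def __mul__(self, scalar):
        # Accept instances of types registered as numbers.Real
        # (int, bool, float, Fraction, ...); reject complex, str, etc.
        if isinstance(scalar, numbers.Real):
            return Vec(n * scalar for n in self)
        return NotImplemented

    def __rmul__(self, scalar):
        return self * scalar


v = Vec([1, 2, 3])
print(3 * v)               # Vec([3.0, 6.0, 9.0])
print(v * Fraction(1, 2))  # Vec([0.5, 1.0, 1.5])
try:
    v * 2j                 # complex is not numbers.Real
except TypeError as e:
    print(e)               # unsupported operand type(s) for *: 'Vec' and 'complex'
```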

## Using `@` as an Infix Operator

The `@` sign is well known as the prefix of function decorators, but since 2015 it can also be used as an infix operator. For years, the dot product was written as `numpy.dot(a, b)` in NumPy. Function call notation makes longer formulas harder to translate from mathematical notation to Python, so the numerical computing community lobbied for PEP 465, "A dedicated infix operator for matrix multiplication", which was implemented in Python 3.5. Today, you can write `a @ b` to compute the dot product of two NumPy arrays. No built-in Python type implements `@`; it exists for libraries and user-defined classes, which is why the `Vector` below raises `TypeError` until it defines `__matmul__`.

In [75]:
va = Vector([1, 2, 3])
vz = Vector([5, 6, 7])
try:
    va @ vz == 38
except TypeError as e:
    print(f"{e=}")

e=TypeError("unsupported operand type(s) for @: 'Vector' and 'Vector'")


In [76]:
from collections import abc

def __matmul__(self, other):
    if (isinstance(other, abc.Sized) and
        isinstance(other, abc.Iterable)):
        if len(self) == len(other):
            return sum(a*b for a, b in zip(self, other))
        else:
            raise ValueError('@ requires vectors of equal length.')
    else:
        return NotImplemented
    
def __rmatmul__(self, other):
    return self @ other

Vector.__matmul__ = __matmul__
Vector.__rmatmul__ = __rmatmul__

In [77]:
va = Vector([1, 2, 3])
vz = Vector([5, 6, 7])
try:
    print(f"{va @ vz == 38=}")
except TypeError as e:
    print(f"{e=}")

va @ vz == 38=True


## Wrapping Up Arithmetic Operators

### Infix Operator Method Names

| Operator | Forward | Reverse | In-place | Description |
| :------: | :------ | :------ | :------- | :---------- |
| `+` | `__add__` | `__radd__` | `__iadd__` | Addition or concatenation |
| `-` | `__sub__` | `__rsub__` | `__isub__` | Subtraction |
| `*` | `__mul__` | `__rmul__` | `__imul__` | Multiplication or repetition |
| `/` | `__truediv__` | `__rtruediv__` | `__itruediv__` | True division |
| `//` | `__floordiv__` | `__rfloordiv__` | `__ifloordiv__` | Floor division |
| `%` | `__mod__` | `__rmod__` | `__imod__` | Modulo |
| `divmod()` | `__divmod__` | `__rdivmod__` | (none) | Returns the tuple `(a // b, a % b)` |
| `**`, `pow()` | `__pow__` | `__rpow__` | `__ipow__` | Exponentiation |
| `@` | `__matmul__` | `__rmatmul__` | `__imatmul__` | Matrix multiplication |
| `&` | `__and__` | `__rand__` | `__iand__` | Bitwise and |
| `\|` | `__or__` | `__ror__` | `__ior__` | Bitwise or |
| `^` | `__xor__` | `__rxor__` | `__ixor__` | Bitwise xor |
| `<<` | `__lshift__` | `__rlshift__` | `__ilshift__` | Bitwise shift left |
| `>>` | `__rshift__` | `__rrshift__` | `__irshift__` | Bitwise shift right |
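The `__radd__` used for `Vector` can simply delegate to `self + other` because vector addition is commutative. For a non-commutative operator such as `-`, the reverse method receives the operands swapped, so it must compute `other - self`. Delegating to `self - other` would silently flip the sign. Here is a minimal sketch with a hypothetical `Amount` class:

```python
class Amount:
    """Hypothetical number wrapper used to show operand order in __rsub__."""

    def __init__(self, value):
        self.value = value

    def __sub__(self, other):
        # Called for `self - other`: self is the LEFT operand
        if isinstance(other, Amount):
            return Amount(self.value - other.value)
        if isinstance(other, (int, float)):
            return Amount(self.value - other)
        return NotImplemented

    def __rsub__(self, other):
        # Called for `other - self` after other.__sub__ gives up:
        # self is the RIGHT operand, so compute other - self.
        if isinstance(other, (int, float)):
            return Amount(other - self.value)
        return NotImplemented


print((Amount(10) - 3).value)   # 7
print((10 - Amount(3)).value)   # 7, not -7
```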


### Rich Comparison Operators

| Group | Infix operator | Forward method call | Reverse method call | Fallback |
| :---- | :------------- | :------------------ | :------------------ | :------- |
| **Equality** | `a == b` | `a.__eq__(b)` | `b.__eq__(a)` | Return `id(a) == id(b)` |
| | `a != b` | `a.__ne__(b)` | `b.__ne__(a)` | Return `not (a == b)` |
| **Ordering** | `a > b` | `a.__gt__(b)` | `b.__lt__(a)` | Raise `TypeError` |
| | `a < b` | `a.__lt__(b)` | `b.__gt__(a)` | Raise `TypeError` |
| | `a >= b` | `a.__ge__(b)` | `b.__le__(a)` | Raise `TypeError` |
| | `a <= b` | `a.__le__(b)` | `b.__ge__(a)` | Raise `TypeError` |
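The fallbacks in the last column can be observed with two hypothetical classes. `Opaque` defines no comparison methods, and `Point` has an `__eq__` that returns `NotImplemented` for foreign types. Equality falls back to identity, `!=` inverts the result of `__eq__`, and ordering has no fallback at all.

```python
class Opaque:
    """Hypothetical class that defines no comparison methods of its own."""


class Point:
    """Hypothetical class whose __eq__ only understands other Points."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        if isinstance(other, Point):
            return (self.x, self.y) == (other.x, other.y)
        return NotImplemented   # Python tries the reverse call, then falls back


a, b = Opaque(), Opaque()
print(a == a, a == b)              # True False: fallback compares identity
print(Point(1, 2) == Point(1, 2))  # True: Point.__eq__
print(Point(1, 2) != Point(1, 2))  # False: inherited __ne__ inverts __eq__
print(Point(1, 2) == (1, 2))       # False: both sides return NotImplemented
try:
    a < b
except TypeError as e:
    print(e)   # '<' not supported between instances of 'Opaque' and 'Opaque'
```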

In [78]:
def __eq__(self, other):
    return (len(self) == len(other) and
            all(a==b for a, b in zip(self, other)))

Vector.__eq__ = __eq__

In [79]:
va = Vector([1.0, 2.0, 3.0])
vb = Vector(range(1, 4))
va == vb

True

In [86]:
vc = Vector([1, 2])
v2d = Vector2d(1,2)
try:
    print(f"{vc == v2d=}")
except Exception as e:
    print(f"{e=}")

vc == v2d=True


In [87]:
def _vector2d_len(self):
    # A Vector2d always has exactly two components
    return 2

Vector2d.__len__ = _vector2d_len

In [88]:
vc = Vector([1, 2])
v2d = Vector2d(1,2)
try:
    print(f"{vc == v2d=}")
except Exception as e:
    print(f"{e=}")

vc == v2d=True


In [90]:
t3 = (1, 2, 3)
va == t3

True

In [91]:
def __eq__(self, other):
    if isinstance(other, Vector):
        return (len(self) == len(other) and
                all(a==b for a, b in zip(self, other)))
    else:
        return NotImplemented
    
Vector.__eq__ = __eq__

In [92]:
t3 = (1, 2, 3)
va == t3

False

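## Augmented Assignment Operators

`Vector` defines neither `__iadd__` nor `__imul__`, so Python evaluates `v1 += x` as `v1 = v1 + x` and `v1 *= x` as `v1 = v1 * x`. Each statement builds a new `Vector` and rebinds `v1` to it, which is why `id(v1)` changes in the cells below while `v1_alias` still refers to the original object.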

In [93]:
v1 = Vector([1, 2, 3])
v1_alias = v1

In [94]:
id(v1)

140059492755088

In [95]:
v1 += Vector( [4, 5, 6])
v1

Vector([5.0, 7.0, 9.0])

In [96]:
id(v1)

140059126406160

In [97]:
v1_alias

Vector([1.0, 2.0, 3.0])

In [98]:
v1 *= 11
v1

Vector([55.0, 77.0, 99.0])

In [99]:
id(v1)

140059126399504
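The cells above show `+=` and `*=` rebinding `v1` to new objects. For a mutable type that is often not what you want. Such a class can implement the in-place method itself: update `self`, then return `self`, so the name stays bound to the same object. The `Tally` class below is a hypothetical sketch of that pattern and is not part of the notebook.

```python
class Tally:
    """Hypothetical mutable container that supports in-place +=."""

    def __init__(self, items=()):
        self._items = list(items)

    def __iter__(self):
        return iter(self._items)

    def __repr__(self):
        return f'Tally({self._items!r})'

    def __add__(self, other):
        # Infix +: build and return a NEW object, never modify self
        if isinstance(other, Tally):
            return Tally(self._items + other._items)
        return NotImplemented

    def __iadd__(self, other):
        # In-place +=: accept any iterable, mutate self, and return self.
        # list(other) is built first so that `t += t` does not loop forever.
        try:
            new_items = list(other)
        except TypeError:
            return NotImplemented
        self._items.extend(new_items)
        return self


t = Tally([1, 2])
t_alias = t
t += [3, 4]           # any iterable works with +=
print(t is t_alias)   # True: same object, mutated in place
print(t_alias)        # Tally([1, 2, 3, 4])
t2 = t + Tally([5])
print(t2 is t)        # False: + creates a new object
```

This mirrors the built-in `list`: `+` requires another list, while `+=` accepts any iterable.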