# Types

The largest difference between TorchScript and full Python is that TorchScript only supports a small set of types that are needed to express neural net models:
1. ```Tensor``` :=  A PyTorch tensor of any dtype, dimension, or backend

2. ```Tuple[T0, T1, ..., TN]``` := A tuple containing subtypes T0, T1, etc. (e.g. Tuple[Tensor, Tensor])

3. ```bool``` := A boolean value

4. ```int``` :=  A scalar integer

5. ```float``` := A scalar floating point number

6. ```str``` := A string

7. ```List[T]```  := A list whose members are all of type T

8. ```Optional[T]```  := A value which is either None or type T

9. ```Dict[K, V]``` := A dict with key type K and value type V. Only str, int, and float are allowed as key types.

10. ```T```  := A TorchScript Class

11. ```E```   :=  A TorchScript Enum

12. ```NamedTuple[T0, T1, ...]```  :=  A ```collections.namedtuple``` tuple type

Unlike Python, each variable in a TorchScript function must have a single static type. This makes it easier to optimize TorchScript functions.

### Example (A Type Mismatch Error)

```python
import torch

@torch.jit.script
def error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
```

```
RuntimeError:

Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
  File "<ipython-input-1-ec0db09b08ea>", line 5
@torch.jit.script
def error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
  File "<ipython-input-1-ec0db09b08ea>", line 9
    else:
        r = 4
    return r
           ~ <--- HERE
```
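A minimal sketch of one way to resolve the mismatch, assuming a one-element float tensor is an acceptable result for the `else` branch: both branches now assign a Tensor, so `r` has a single static type. The function name `no_error` is hypothetical.

```python
import torch


@torch.jit.script
def no_error(x: torch.Tensor) -> torch.Tensor:
    if x:
        r = torch.rand(1)
    else:
        # torch.tensor(...) makes this branch a Tensor as well,
        # so r has the same static type on both paths
        r = torch.tensor([4.0])
    return r


print(no_error(torch.tensor(False)))  # tensor([4.])
```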


## Unsupported Typing Constructs

TorchScript does not support all features and types of the ```typing``` module. Some of these are more fundamental things that are unlikely to be added in the future while others may be added if there is enough user demand to make it a priority.

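The following types are unsupported (this list reflects the PyTorch version these notes were written against; newer releases have added support for some of them, for example ```typing.Union```):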
- ```typing.Any```
- ```typing.NoReturn```
- ```typing.Union```
- ```typing.Sequence```
- ```typing.Callable```
- ```typing.Literal```
- ```typing.ClassVar```
- ```typing.Final```
- ```typing.AnyStr```
- ```typing.overload```

## Default Types

By default, all parameters to a TorchScript function are assumed to be Tensor. To specify that an argument to a TorchScript function is of another type, use MyPy-style type annotations with the types listed above:

```python
import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))
```

```
tensor([4.2062, 4.3113, 4.0215])
```
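The same signature can also be written with Python 3 annotations instead of a `# type:` comment. The sketch below (the names `foo_py3` and `add_one` are illustrative) also shows the default in action: an unannotated parameter is compiled as a Tensor, so passing a plain `int` is rejected when the function is called.

```python
import torch
from typing import Tuple


# Equivalent to foo above, but annotated with Python 3 type hints
@torch.jit.script
def foo_py3(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x


# No annotations: x defaults to Tensor
@torch.jit.script
def add_one(x):
    return x + 1


print(foo_py3(3, (torch.ones(3), torch.ones(3))))  # tensor([5., 5., 5.])

try:
    add_one(3)  # an int does not match the inferred Tensor parameter
except RuntimeError as err:
    print("add_one(3) was rejected:", err)
```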


An empty list is assumed to be ```List[Tensor]``` and an empty dict ```Dict[str, Tensor]```. To instantiate an empty list or dict of other types, use <i>Python 3 type hints</i>:

### Example (Type Annotation for Python 3)

```python
# type annotations for Python 3
import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(nn.Module):
    def __init__(self):
        super(EmptyDataStructures, self).__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

x
```

```
RecursiveScriptModule(original_name=EmptyDataStructures)
```
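To see why the annotation matters, here is a sketch assuming the default described above: without a hint, `[]` is a `List[Tensor]`, so appending an `int` fails to compile, while the annotated version compiles and runs. Both function names are hypothetical.

```python
import torch
from typing import List


def unannotated_list():
    xs = []       # no hint: TorchScript treats this as List[Tensor]
    xs.append(1)  # int does not match Tensor, so compilation fails
    return xs


def annotated_list() -> List[int]:
    xs: List[int] = []  # the hint fixes the element type to int
    xs.append(1)
    return xs


try:
    torch.jit.script(unannotated_list)
    unannotated_compiled = True
except RuntimeError:
    unannotated_compiled = False

scripted = torch.jit.script(annotated_list)
print(unannotated_compiled)
print(scripted())  # [1]
```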

## Optional Type Refinement

TorchScript will refine the type of a variable of type ```Optional[T]``` when a comparison to ```None``` is made inside the conditional of an if-statement or checked in an ```assert```.

The compiler can reason about multiple ```None``` checks that are combined with ```and```, ```or```, and ```not```. Refinement will also occur for else blocks of if-statements that are not explicitly written.

The ```None``` check must be within the if-statement’s condition; assigning a ```None``` check to a variable and using it in the if-statement’s condition will not refine the types of variables in the check. Only local variables will be refined: an attribute like ```self.x``` will not be, and must be assigned to a local variable to be refined.

### Example (Refining Types on Parameters and Locals)

```python
import torch
import torch.nn as nn
from typing import Optional


class M(nn.Module):
    # If `z` is None, its type cannot be inferred, so it must be declared explicitly
    z: Optional[int]

    def __init__(self, z):
        super(M, self).__init__()
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x += 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

module
```

```
RecursiveScriptModule(original_name=M)
```
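The rule that the ```None``` check must sit in the condition itself can be illustrated with two small functions. This is a sketch with hypothetical names: `not_refined` stores the check in a variable and is expected to fail to compile, while `refined` puts the check in the `if` and compiles.

```python
import torch
from typing import Optional


def not_refined(x: Optional[int]) -> int:
    is_set = x is not None  # the None check is stored in a variable...
    if is_set:
        return x + 1        # ...so x is still Optional[int] here: compile error
    return 0


def refined(x: Optional[int]) -> int:
    if x is not None:       # the None check is in the condition itself
        return x + 1        # x is refined to int inside this branch
    return 0


try:
    torch.jit.script(not_refined)
    not_refined_compiled = True
except RuntimeError:
    not_refined_compiled = False

scripted_refined = torch.jit.script(refined)
print(scripted_refined(None), scripted_refined(2))  # 0 3
```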

## TorchScript Classes

Python classes can be used in TorchScript if they are annotated with ```@torch.jit.script```, similar to how you would declare a TorchScript function:

```
%run Code/TorchScriptCodeVersionCodeFiles/TorchScriptClasses/ErrorExplain.py
```

```
It works..!!
```
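The contents of `ErrorExplain.py` are not shown here, so the following is only an illustrative sketch of a TorchScript class (the class `Counter` and the function `count_up` are hypothetical): every attribute is declared in `__init__()`, and the class is then constructed and used inside another TorchScript function.

```python
import torch


@torch.jit.script
class Counter:
    def __init__(self, start: int):
        # every attribute is declared here, in __init__
        self.count = start

    def increment(self, by: int) -> int:
        self.count += by
        return self.count


@torch.jit.script
def count_up(n: int) -> int:
    c = Counter(0)  # the class can be constructed inside TorchScript
    for i in range(n):
        c.increment(2)
    return c.count


print(count_up(3))  # 6
```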


The supported subset of Python classes is restricted:
- All functions must be valid TorchScript functions (including ```__init__()```).
- Classes must be new-style classes, as TorchScript uses ```__new__()``` to construct them with pybind11.
- TorchScript Classes are statically typed. Members can only be declared by assigning to self in the ```__init__()``` method.

For example, assigning to an attribute of ```self``` that was not declared in the ```__init__()``` method is an error:

```
%run Code/TorchScriptCodeVersionCodeFiles/TorchScriptClasses/SelfoutsideInit_Error.py
```

```
RuntimeError:
Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
  File "D:\ML\PyTorch\PyTorch_Manual\TorchScript Language Reference\Code\TorchScriptCodeVersionCodeFiles\TorchScriptClasses\SelfoutsideInit_Error.py", line 6
    def assign_x(self):
        self.x = torch.rand(2, 3)
        ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
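The script itself is not shown; the sketch below is a hypothetical reconstruction based on the lines in the traceback (the class names `FooBar` and `FooBarFixed` are assumptions). It assumes that scripting a class compiles its methods immediately, so the broken version fails at `torch.jit.script(...)`, and the fix is simply to declare `x` in `__init__()`.

```python
import torch


class FooBar:
    def __init__(self):
        self.y = 1

    def assign_x(self):
        # x was never declared in __init__, so compilation fails
        self.x = torch.rand(2, 3)


try:
    torch.jit.script(FooBar)
    compiled_broken = True
except RuntimeError:
    compiled_broken = False


@torch.jit.script
class FooBarFixed:
    def __init__(self):
        self.y = 1
        # declaring x here gives it a static type (Tensor)
        self.x = torch.zeros(2, 3)

    def assign_x(self):
        self.x = torch.rand(2, 3)


print(compiled_broken)
```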


- No expressions except method definitions are allowed in the body of the class.
- No support for inheritance or any other polymorphism strategy, except for inheriting from ```object``` to specify a new-style class.

After a class is defined, it can be used in both TorchScript and Python interchangeably like any other TorchScript type:

```
%run Code/TorchScriptCodeVersionCodeFiles/TorchScriptClasses/DeclareTorchScriptClass.py
```

```
tensor([[0.7308, 1.4268, 0.9483],
        [0.7914, 1.0562, 1.7423]])
```
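The script's source is not shown, but an output of this shape is consistent with summing two 2x3 tensors stored in a class. A sketch of that pattern, with the hypothetical names `Pair` and `sum_pair`: the instance is built in ordinary Python and passed straight into a TorchScript function.

```python
import torch


@torch.jit.script
class Pair:
    def __init__(self, first, second):
        # unannotated arguments default to Tensor
        self.first = first
        self.second = second


@torch.jit.script
def sum_pair(p: Pair) -> torch.Tensor:
    return p.first + p.second


# Pair is constructed in Python and passed into a TorchScript function
p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))
```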


## TorchScript Enums

Python Enums can be used in TorchScript without any extra annotation or code:

```python
import torch
from enum import Enum

class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True

    return x == y

print(enum_fn(Color.GREEN, Color.GREEN))  # True
```

After an Enum is defined, it can be used in both TorchScript and Python interchangeably like any other TorchScript type.

The values of an enum must be of type ```int```, ```float```, or ```str```, and <b>all values must be of the same type</b>.

```
%run Code/TorchScriptCodeVersionCodeFiles/TorchScriptEnums.py
```

The script runs without errors.
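The ```Color``` example uses ```int``` values; a sketch with ```str``` values (all of the same type, as required) looks the same. The enum `Mode` and the function `mode_name` are hypothetical, and the sketch assumes that reading `.value` on an enum member is supported inside TorchScript.

```python
import torch
from enum import Enum


class Mode(Enum):
    # every value is a str, satisfying the same-type rule
    TRAIN = "train"
    EVAL = "eval"


@torch.jit.script
def mode_name(m: Mode) -> str:
    return m.value


print(mode_name(Mode.EVAL))  # eval
```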

## Named Tuples

Types produced by ```collections.namedtuple``` can be used in TorchScript:

```python
# Types produced by collections.namedtuple can be used in TorchScript.
import torch
from collections import namedtuple

Point = namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))
```

```
%run Code/TorchScriptCodeVersionCodeFiles/NamedTuples.py
```

```
tensor([0.7753, 1.3711, 1.3238])
```
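With `collections.namedtuple` every field defaults to Tensor. If a field should have another type, one option is a `typing.NamedTuple` subclass with annotated fields. A minimal sketch (the names `Point3` and `scaled_total` are hypothetical):

```python
import torch
from typing import NamedTuple


class Point3(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor
    scale: float  # a non-Tensor field, typed through the annotation


@torch.jit.script
def scaled_total(p: Point3) -> torch.Tensor:
    return (p.x + p.y) * p.scale


print(scaled_total(Point3(torch.ones(3), torch.ones(3), 0.5)))  # tensor([1., 1., 1.])
```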


# Important

## Function Calls

Calls to <i>built-in functions</i>:

```python
torch.rand(3, dtype=torch.int)
```

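```python
torch.rand(3, dtype=torch.int)
```

```
RuntimeError: "check_uniform_bounds" not implemented for 'Int'
```

This call fails because ```torch.rand``` samples from a uniform distribution and is only implemented for floating-point dtypes; the snippet only illustrates the call syntax for builtins (positional plus keyword arguments). For random integers, use ```torch.randint``` instead.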
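Inside TorchScript, the same kind of builtin call works once the dtype is valid for the op. A minimal sketch (the function name `builtin_calls` is hypothetical): `torch.rand` is given a floating-point dtype, and integer samples come from `torch.randint(low, high, size)`.

```python
import torch
from typing import Tuple


@torch.jit.script
def builtin_calls() -> Tuple[torch.Tensor, torch.Tensor]:
    # floating-point dtypes are valid for torch.rand
    a = torch.rand([3], dtype=torch.float64)
    # random integers in [0, 10) come from torch.randint
    b = torch.randint(0, 10, [3], dtype=torch.int)
    return a, b


a, b = builtin_calls()
print(a.dtype, b.dtype)  # torch.float64 torch.int32
```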

Calls to other script functions:

```python
@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)
```
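The called function does not have to be decorated itself: when a script function calls a plain Python function, TorchScript compiles that function as well, provided it is valid TorchScript. A minimal sketch with the hypothetical names `plain_helper` and `caller`:

```python
import torch


def plain_helper(x: torch.Tensor) -> torch.Tensor:
    # not decorated: compiled automatically because caller uses it
    return x * 3


@torch.jit.script
def caller(x: torch.Tensor) -> torch.Tensor:
    return plain_helper(x) + 1


print(caller(torch.ones(2)))  # tensor([4., 4.])
```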

## Method Calls

Calls to methods of built-in types like Tensor, e.g. ```x.mm(y)```.

On modules, methods must be compiled before they can be called. The TorchScript compiler recursively compiles the methods it sees when compiling other methods.

By default, compilation starts at the ```forward()``` method. Any methods called by ```forward``` will be compiled, as will any methods called by those methods, and so on. To start compilation at a method other than ```forward```, use the ```@torch.jit.export``` decorator (```forward``` is implicitly marked ```@torch.jit.export```).

Calling a submodule directly (i.e. ```self.resnet(input)```) is equivalent to calling its ```forward``` method (i.e. ```self.resnet.forward(input)```).
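The compilation rules above can be sketched with a small module; all class and method names here are hypothetical. `helper` is compiled because `forward` calls it, `doubled_only` is compiled because of `@torch.jit.export`, and `python_only` is never reached from either, so it is assumed to be left uncompiled even though its body is not valid for a Tensor.

```python
import torch
import torch.nn as nn


class Inner(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def helper(self, x: torch.Tensor) -> torch.Tensor:
        # compiled because forward calls it
        return x + 1

    @torch.jit.export
    def doubled_only(self, x: torch.Tensor) -> torch.Tensor:
        # compiled because of @torch.jit.export; calls the submodule's forward explicitly
        return self.inner.forward(x)

    def python_only(self, x):
        # never called from forward or an exported method, so never compiled
        return x.no_such_method()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # self.inner(x) is equivalent to self.inner.forward(x)
        return self.helper(self.inner(x))


m = torch.jit.script(Outer())
print(m(torch.ones(2)))             # tensor([3., 3.])
print(m.doubled_only(torch.ones(2)))  # tensor([2., 2.])
```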