# Implementing an Operation

## What Should Be Implemented as an Operation?

The first step is to ensure that what you want to implement is actually an operation.
Most operations are non-trainable, but this is not a strict requirement.

Examples of operations include `Reshape`, `Concatenate`, and `Dropout`.

**NOTE:** Some operations are trainable. This is useful if the standard constructor of a trainable layer is not well suited for Deeplay, or if a layer needs a custom forward pass, as is the case for attention layers. In this case, make sure that the module is not actually a block: if it contains several layers, it should be implemented as a block instead.
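As a rough illustration of this distinction, the following plain-PyTorch sketch contrasts a hypothetical trainable operation, `SelfAttentionSketch`, with a hypothetical non-trainable one, `FlattenSketch`. The attention sketch holds its weights directly as a single `nn.Parameter` and defines a custom forward pass instead of wrapping several layers. Both classes subclass `torch.nn.Module` only to keep the example self-contained; a Deeplay operation would inherit from `DeeplayModule` or a suitable base class instead.

```python
import math

import torch
from torch import nn


class SelfAttentionSketch(nn.Module):
    """Hypothetical trainable operation: single-head self-attention.

    The weights are a single parameter (no sub-layers), and the forward
    pass is custom.
    """

    def __init__(self, features: int) -> None:
        super().__init__()
        # One learnable matrix projecting to queries, keys and values at once.
        self.w_qkv = nn.Parameter(torch.randn(features, 3 * features) / math.sqrt(features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, sequence, features).
        q, k, v = (x @ self.w_qkv).chunk(3, dim=-1)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        return scores.softmax(dim=-1) @ v


class FlattenSketch(nn.Module):
    """Hypothetical non-trainable operation, for comparison."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.flatten(start_dim=1)


def num_parameters(module: nn.Module) -> int:
    # Total number of learnable values in the module.
    return sum(p.numel() for p in module.parameters())


attention = SelfAttentionSketch(features=8)
flatten = FlattenSketch()
x = torch.randn(2, 5, 8)
print(attention(x).shape)         # torch.Size([2, 5, 8])
print(num_parameters(attention))  # 192
print(num_parameters(flatten))    # 0
```

Counting parameters is a quick way to check whether an operation is trainable: the attention sketch has 8 × 24 = 192 learnable weights, while the flatten sketch has none.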

## Implementing an Operation

Here you'll see the steps you should follow to implement an operation in Deeplay. You'll do this by implementing the `Reshape` operation.

### 1. Create a New File

The first step is to create a new file in the `deeplay/ops` directory. It
can also go in a deeper subdirectory if that fits how related operations are organized.

**The base class.**
Some operations share a common base class, such as `ShapeOp` and `MergeOp`.
If your operation fits into one of these categories, inherit from that
base class. Otherwise, inherit from `DeeplayModule`.
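To illustrate why a shared base class can be convenient in general, the sketch below defines a hypothetical marker class, `ShapeOpSketch`, in plain PyTorch. Generic code can then find every shape operation in a model with a single `isinstance` check. This is not Deeplay's `ShapeOp`; it only illustrates the pattern.

```python
import torch
from torch import nn


class ShapeOpSketch(nn.Module):
    """Hypothetical marker base class for operations that only change shape."""


class FlattenSketch(ShapeOpSketch):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.flatten(start_dim=1)


class UnsqueezeSketch(ShapeOpSketch):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.unsqueeze(self.dim)


model = nn.Sequential(UnsqueezeSketch(1), nn.Conv2d(1, 4, 3), nn.ReLU(), FlattenSketch())

# Collect every shape operation, regardless of its concrete class.
shape_ops = [m for m in model.modules() if isinstance(m, ShapeOpSketch)]
print([type(m).__name__ for m in shape_ops])  # ['UnsqueezeSketch', 'FlattenSketch']
print(model(torch.randn(2, 8, 8)).shape)      # torch.Size([2, 144])
```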


The code below implements the `Reshape` operation as a subclass of `ShapeOp`.

In [None]:
from deeplay.ops.shape import ShapeOp

class Reshape(ShapeOp):
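    def __init__(self, *shape, copy=False):
        super().__init__()
        self.shape = shape  # target shape, e.g. Reshape(3, 6) -> (3, 6)
        self.copy = copy    # if True, return a copy instead of a view

    def forward(self, x):
        x = x.view(*self.shape)  # the view shares memory with the input
        if self.copy:
            x = x.clone()  # allocate new memory for the result
        return x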
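The `copy` flag matters because `Tensor.view` returns a tensor that shares memory with its input, while `clone` allocates new memory. The sketch below reproduces the body of `forward` as a standalone function (`reshape_like_forward` is a hypothetical name) and checks this with `data_ptr()`.

```python
import torch


def reshape_like_forward(x: torch.Tensor, shape: tuple, copy: bool = False) -> torch.Tensor:
    """Same logic as Reshape.forward, written as a plain function (hypothetical helper)."""
    y = x.view(*shape)
    if copy:
        y = y.clone()
    return y


x = torch.zeros(2, 9)
shared = reshape_like_forward(x, (3, 6), copy=False)
copied = reshape_like_forward(x, (3, 6), copy=True)

# A view shares storage with x, so writing through it also changes x.
shared[0, 0] = 1.0
print(x[0, 0].item())                     # 1.0
print(shared.data_ptr() == x.data_ptr())  # True

# A clone owns separate memory and keeps the old value.
print(copied.data_ptr() == x.data_ptr())  # False
print(copied[0, 0].item())                # 0.0

# view also needs a compatible memory layout: a transposed tensor cannot be viewed as (3, 6).
try:
    reshape_like_forward(x.t(), (3, 6))
except RuntimeError:
    print("view failed on a non-contiguous input")
```

If inputs may be non-contiguous, `torch.reshape` or `x.contiguous().view(...)` are the usual alternatives to consider.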

### 2. Add Annotations

It is important to add type annotations to the class attributes and method
signatures so that users know what to expect. Annotations also help IDEs
provide autocompletion.

In [None]:
from typing import Tuple

import torch
from deeplay.ops.shape import ShapeOp

class Reshape(ShapeOp):

    shape: Tuple[int, ...]
    copy: bool

    def __init__(
        self,
        *shape: int,
        copy: bool = False,
    ) -> None:
        super().__init__()
        self.shape = shape
        self.copy = copy

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        x = x.view(*self.shape)
        if self.copy:
            x = x.clone()
        return x
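Besides being read statically by IDEs and type checkers, class and method annotations are also available at runtime. The sketch below shows this with the standard-library `typing.get_type_hints` on a hypothetical plain-Python class, `ReshapeConfig`, which mirrors the annotations of `Reshape` without depending on Deeplay or PyTorch.

```python
from typing import Tuple, get_type_hints


class ReshapeConfig:
    """Hypothetical plain class mirroring the annotations of Reshape."""

    shape: Tuple[int, ...]
    copy: bool

    def __init__(self, *shape: int, copy: bool = False) -> None:
        self.shape = shape
        self.copy = copy


# Class-level annotations, as declared in the class body.
class_hints = get_type_hints(ReshapeConfig)
print(class_hints["copy"])  # <class 'bool'>

# Signature annotations; the hint for *shape is the type of each element.
init_hints = get_type_hints(ReshapeConfig.__init__)
print(init_hints["shape"])  # <class 'int'>
```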

### 3. Document the Operation

The next step is to document the operation. This should include a description of 
the operation, the input and output shapes, and the arguments that can be passed to
the operation.

In [None]:
import torch
from deeplay.ops.shape import ShapeOp

class Reshape(ShapeOp):
    """An operation for reshaping a tensor.

    This operation reshapes a tensor to a new shape, given as positional
    integer arguments. The `copy` parameter controls whether the reshaped
    tensor is a view of the original tensor or a copy.

    Parameters
    ----------
    *shape : int
        The new shape of the tensor.
    copy : bool
        Whether to return a copy of the reshaped tensor.

    Attributes
    ----------
    shape : Tuple[int, ...]
        The new shape of the tensor.
    copy : bool
        Whether to return a copy of the reshaped tensor.
    
    Input
    -----
    x : torch.Tensor (Any, ...)
        The input tensor to reshape.
    
    Output
    ------
    y : torch.Tensor
        The reshaped tensor (*shape).

    Evaluation
    ----------
    y = x.view(*shape) if not copy else x.view(*shape).clone()

    Examples
    --------
    >>> operation = Reshape(3, 6, copy=True).build()
    >>> x = torch.randn(2, 9)
    >>> y = operation(x)
    >>> y.shape
    torch.Size([3, 6])

    """
    
    def __init__(
        self,
        *shape: int,
        copy: bool = False,
    ) -> None:
        super().__init__()
        self.shape = shape
        self.copy = copy

    def forward(  
        self, 
        x: torch.Tensor,  
    ) -> torch.Tensor:
        """Forward pass of the reshape operation.
        
        Parameters
        ----------
        x : torch.Tensor
            The input tensor to reshape.
        
        Returns
        -------
        torch.Tensor
            The reshaped tensor.
        """
        x = x.view(*self.shape)
        if self.copy:
            x = x.clone()
        return x
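The `Examples` section in the docstring above uses the `>>>` doctest format, so it can also serve as a test. Running that particular example requires Deeplay and PyTorch; as a self-contained sketch, the code below applies Python's standard `doctest` module to a hypothetical helper, `flat_size`, whose NumPy-style docstring contains a similar example.

```python
import doctest


def flat_size(*shape: int) -> int:
    """Return the number of elements in a tensor of the given shape.

    Parameters
    ----------
    *shape : int
        The shape of the tensor.

    Returns
    -------
    int
        The product of the dimensions.

    Examples
    --------
    >>> flat_size(3, 6)
    18
    >>> flat_size(2, 9) == flat_size(3, 6)
    True
    """
    size = 1
    for dim in shape:
        size *= dim
    return size


# Collect and run only the examples in this function's docstring.
finder = doctest.DocTestFinder()
runner = doctest.DocTestRunner(verbose=False)
for test in finder.find(flat_size, "flat_size", globs={"flat_size": flat_size}):
    runner.run(test)
results = runner.summarize(verbose=False)
print(results.failed, results.attempted)  # 0 2
```

A failing example would be reported by the runner and counted in `results.failed`.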