# Overview

> Note: Overusing pinned memory can cause serious problems when the system is running low on RAM, because page-locked memory cannot be swapped out. Also keep in mind that pinning itself is often an expensive operation.

Host-to-GPU copies are much faster when they originate from pinned (page-locked) memory. CPU tensors and storages expose a `pin_memory()` method that returns a copy of the object with its data placed in a pinned region.
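A minimal sketch of `pin_memory()` on a tensor, showing that it returns a pinned copy and leaves the original pageable tensor untouched. The tensor size is arbitrary, and the `torch.cuda.is_available()` guard is an addition so the snippet also runs on CPU-only machines, where pinning is not available.

```python
import torch

# A regular (pageable) CPU tensor; the size is arbitrary.
cpu_tensor = torch.randn(1024, 1024)

if torch.cuda.is_available():
    # pin_memory() returns a copy whose data lives in page-locked memory.
    pinned = cpu_tensor.pin_memory()
    print(cpu_tensor.is_pinned())           # False: the original is unchanged
    print(pinned.is_pinned())               # True
    print(torch.equal(cpu_tensor, pinned))  # True: same values, separate storage
else:
    print("CUDA is not available, so this sketch skips pinning.")
```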

Once we pin a tensor or storage, we can use asynchronous GPU copies. Just pass an additional `non_blocking=True` argument to a `to()` or a `cuda()` call. This can be used to overlap data transfers with computation.
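One way to get such an overlap is sketched below, assuming a single CUDA device. The copy of the next batch is issued with `non_blocking=True` on a separate stream while the default stream computes on the current batch. The side stream, matrix sizes, and variable names are illustrative choices, not part of the original text. If everything runs on the default stream, a non-blocking copy still lets the CPU continue, but the GPU executes the copy and the kernels on that stream one after another.

```python
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    copy_stream = torch.cuda.Stream()  # side stream used only for transfers

    weight = torch.randn(2048, 2048, device=device)
    current = torch.randn(2048, 2048, device=device)
    next_host = torch.randn(2048, 2048).pin_memory()  # pinned source buffer

    # Issue the host-to-device copy of the next batch on the side stream...
    with torch.cuda.stream(copy_stream):
        next_gpu = next_host.to(device, non_blocking=True)

    # ...while the default stream computes on the current batch.
    out = current @ weight

    # The default stream must wait for the copy before it reads next_gpu.
    torch.cuda.current_stream().wait_stream(copy_stream)
    # next_gpu was allocated on copy_stream but is used on the default stream;
    # record that use so the caching allocator does not reuse its memory early.
    next_gpu.record_stream(torch.cuda.current_stream())
    out_next = next_gpu @ weight

    torch.cuda.synchronize()  # wait for all queued GPU work to finish
else:
    print("CUDA is not available, so this sketch skips the overlap example.")
```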

For data loading, passing `pin_memory=True` to a `DataLoader` automatically puts the fetched data Tensors in pinned memory, enabling faster data transfer to CUDA-enabled GPUs.
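A minimal sketch with the default `collate_fn`, which returns each batch as a list of Tensors, so no custom pinning method is needed. The dataset contents and batch size are arbitrary. Tying `pin_memory` to `torch.cuda.is_available()` is an assumption added so the snippet also runs without a GPU.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

features = torch.randn(8, 3)
labels = torch.arange(8)
dataset = TensorDataset(features, labels)

# Only request pinning when a CUDA device is present.
use_pinning = torch.cuda.is_available()
device = torch.device("cuda" if use_pinning else "cpu")
loader = DataLoader(dataset, batch_size=4, pin_memory=use_pinning)

batch_shapes, pinned_flags = [], []
for x, y in loader:
    # The default collate_fn yields [features_batch, labels_batch].
    if use_pinning:
        pinned_flags.append(x.is_pinned() and y.is_pinned())
    # Pinned batches can be copied to the GPU asynchronously.
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    batch_shapes.append((tuple(x.shape), tuple(y.shape)))

print(batch_shapes)  # two batches of 4 samples each
```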

The default memory pinning logic only recognizes Tensors, and maps and iterables containing Tensors. If the pinning logic sees a batch that is a custom type (which happens when you pass a `collate_fn` that returns a custom batch type), or if each element of your batch is a custom type, it will not recognize them and will return that batch (or those elements) without pinning the memory. To enable memory pinning for custom batch or data types, define a `pin_memory()` method on your custom type(s).


See the example below:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

class SimpleCustomBatch:
    def __init__(self, data):
        # `data` is a list of (input, target) pairs; zip(*data) regroups it
        # into one tuple of inputs and one tuple of targets.
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # Custom memory pinning method on the custom batch type; the DataLoader
    # calls it when pin_memory=True.
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
```

```
True
True
True
True
True
True
True
True
True
True
```
