# Downstream Task Example

This guide gives you all the information you need to implement a downstream task on top of our pre-training (or on top of a pre-training you obtained by running our code on your own data).
Each section covers one file and gives a short example of what you need to implement to make everything work.

## 1. Dataloader

Implement a `LightningDataModule` that loads the images and annotations needed for training. The dataset it wraps must have a `__getitem__` method that returns a single sample of the dataset, and a `__len__` method that returns the number of samples (PyTorch's `DataLoader` relies on both).

Here is an example:

```python
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pytorch_lightning as pl
import torch
import yaml
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# A transformation receives (img_pil, img, anno) and returns the transformed triple
Transformation = Callable[[Image.Image, torch.Tensor, torch.Tensor],
                          Tuple[Image.Image, torch.Tensor, torch.Tensor]]


class UAVDataModule(pl.LightningDataModule):

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    def setup(self, stage: Optional[str] = None) -> None:
        """ Create the datasets (Lightning passes the current stage, e.g. 'fit'). """
        path_to_split_info = self.cfg['data']['split']

        with open(path_to_split_info) as istream:
            split_info = yaml.safe_load(istream)

        path_to_dataset = self.cfg['data']['path_to_dataset']

        train_filenames = split_info['train']

        # similarly, you can define datasets for the validation and test sets
        self._uav_train = UAVDataset(path_to_dataset,
                                     train_filenames,
                                     # if you need a pre-processing step for the images, transformations are the way to go;
                                     # get_transformations is your own function that builds them from the config
                                     transformations=get_transformations(self.cfg)
                                     )

    # similarly, you can define dataloaders for the validation and test sets
    def train_dataloader(self) -> DataLoader:
        shuffle: bool = self.cfg['data']['shuffle_train']
        batch_size: int = self.cfg['data']['batch_size_train']
        n_workers: int = self.cfg['data']['num_workers']

        loader = DataLoader(self._uav_train, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers)

        return loader


class UAVDataset(Dataset):
    """ Represents the UAV dataset. """

    def __init__(self, path_to_dataset: str, filenames: List[str], transformations: List[Transformation]):
        super().__init__()

        # get path to all RGB images
        self.path_to_images = path_to_dataset + ...

        self.image_files: List[str] = []
        for fname in os.listdir(self.path_to_images):
            if fname in filenames:
                self.image_files.append(fname)

        # get path to all ground-truth semantic masks
        self.path_to_annos = path_to_dataset + ...

        self.anno_files: List[str] = []
        for fname in os.listdir(self.path_to_annos):
            if fname in filenames:
                self.anno_files.append(fname)

        # os.listdir returns files in arbitrary order: sort both lists so that
        # image_files[i] and anno_files[i] refer to the same sample
        self.image_files.sort()
        self.anno_files.sort()

        # specify image transformations
        self.img_to_tensor = transforms.ToTensor()
        self.transformations = transformations

    def __len__(self) -> int:
        """ Number of samples in the dataset (required by the DataLoader). """
        return len(self.image_files)

    def __getitem__(self, idx: int) -> Dict[str, Union[Image.Image, torch.Tensor, str]]:
        """ Get a single sample of the dataset. """

        # get the current image
        path_to_current_img: str = os.path.join(self.path_to_images, self.image_files[idx])
        img_pil = Image.open(path_to_current_img)
        img = self.img_to_tensor(img_pil)

        # get the corresponding annotation
        path_to_current_anno: str = os.path.join(self.path_to_annos, self.anno_files[idx])
        anno = np.array(Image.open(path_to_current_anno))
        anno = torch.Tensor(anno)

        # apply a set of transformations to the raw image, the image tensor and the annotation
        for transformer in self.transformations:
            img_pil, img, anno = transformer(img_pil, img, anno)

        # note: the DataLoader's default collate_fn cannot batch PIL images, so drop 'img'
        # or pass a custom collate_fn if img_pil is still a PIL image at this point
        return {'img': img_pil, 'data': img, 'anno': anno, 'fname': self.image_files[idx]}
```
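
The dataset above pairs images and annotations only through their filenames, so a missing mask silently shifts every following sample. The sketch below isolates that pairing step in plain Python so it can be checked without files on disk. The function name `matched_split_files` is hypothetical, and it assumes that an image and its annotation have exactly the same filename.

```python
from typing import Iterable, List


def matched_split_files(image_files: Iterable[str],
                        anno_files: Iterable[str],
                        split: Iterable[str]) -> List[str]:
    """Hypothetical helper: filenames of one split that have both an image and a mask.

    Assumes an image and its annotation share the same filename.
    Raises instead of silently misaligning samples.
    """
    wanted = set(split)
    images = {f for f in image_files if f in wanted}
    annos = {f for f in anno_files if f in wanted}

    # files present in only one of the two folders
    unmatched = sorted(images ^ annos)
    if unmatched:
        raise ValueError(f"files without a matching image/annotation: {unmatched}")

    # one sorted list, usable for both folders
    return sorted(images)
```

Used as `matched_split_files(os.listdir(images_dir), os.listdir(annos_dir), split_info['train'])`, the returned list can index both folders, which removes the need to keep two separately sorted lists in sync.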



## 2. Model

If you use our pre-training code as is, the encoder of your model must be a ResNet50. Our method works with any backbone, but if you change anything in our pre-training architecture, make the same changes here. The model also needs a decoder specialized for your downstream task; check the dimensions of the encoder output before designing the decoder.
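
To check the encoder output dimensions, the sketch below lists the feature-map shape after each stage. It assumes the backbone follows the standard torchvision ResNet50 layout (each stride-2 step maps a size n to ceil(n / 2), and the last stage has output stride 32 with 2048 channels, which matches the first projector size in the example below). If our `ResNet` wrapper is built with a different `OS`, the later strides differ, so confirm against a real forward pass. `RESNET50_STAGES` and `resnet50_feature_shapes` are illustrative names.

```python
import math
from typing import Dict, Tuple

# (output channels, cumulative stride) after each stage of a standard ResNet50
RESNET50_STAGES = {
    "conv1": (64, 2),
    "maxpool": (64, 4),
    "layer1": (256, 4),
    "layer2": (512, 8),
    "layer3": (1024, 16),
    "layer4": (2048, 32),
}


def resnet50_feature_shapes(height: int, width: int) -> Dict[str, Tuple[int, int, int]]:
    """Return (channels, height, width) of the feature map after each stage."""
    shapes = {}
    for stage, (channels, stride) in RESNET50_STAGES.items():
        # repeated ceil-halving equals a single ceil division by the cumulative stride
        shapes[stage] = (channels, math.ceil(height / stride), math.ceil(width / stride))
    return shapes
```

For a 512 x 640 image, `resnet50_feature_shapes(512, 640)["layer4"]` gives `(2048, 16, 20)`, so a decoder on the last stage has to upsample by a factor of 32 to produce per-pixel outputs.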

Here is an example model:

```python
import pytorch_lightning as pl
import torch
import torch.nn as nn

# ResNet (including its run_layer helper) is the backbone class from our pre-training code


class MyModel(pl.LightningModule):

    def __init__(self, in_size, dropout, **decoder_kwargs):  # pass all the information needed for your decoder
        super().__init__()
        self.in_size = in_size
        self.encoder = Encoder(in_size=self.in_size, dropout=dropout)
        self.decoder = Decoder(**decoder_kwargs)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = self.encoder(input)
        return self.decoder(output)

    # define your training, validation and test steps
    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    # define your loss
    def compute_loss(self, *args):
        raise NotImplementedError

    # define your optimization method
    def configure_optimizers(self):
        raise NotImplementedError


class Encoder(nn.Module):

    def __init__(self, in_size, dropout, OS=32, bn_d=0.1, model='resnet50'):
        super().__init__()
        self.net = ResNet(in_size, OS, dropout, bn_d, model)
        self.pool = nn.AdaptiveAvgPool2d(1)

        # these are the standard sizes of the projector network after ResNet50;
        # the projector is kept to match the pre-training architecture and is
        # not used in forward() below
        sizes = [2048, 8192, 8192, 8192]
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

    def forward(self, input):
        # define your forward pass:
        # you can use the whole encoder up to the final embedding,
        # or run each layer individually until the depth you like
        skips = {}
        os = 1  # current output stride

        x, skips, os = self.net.run_layer(input, self.net.conv1, skips, os)
        x, skips, os = self.net.run_layer(x, self.net.bn1, skips, os)
        x, skips, os = self.net.run_layer(x, self.net.relu, skips, os)
        x, skips, os = self.net.run_layer(x, self.net.maxpool, skips, os)
        # x, skips, os = ...  (continue in the same way with the remaining layers)
        return x


class Decoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # define your decoder

    def forward(self, input):
        # define your decoder forward pass
        output = input
        return output
```
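
As one possible decoder, here is a minimal sketch for semantic segmentation, which is what the UAV masks in the data loader suggest: a small convolutional head turns the encoder features into per-class logits, and bilinear interpolation brings them back to the input resolution. The class `SegmentationDecoder` and its parameters (`num_classes`, `hidden_channels`) are hypothetical, and it assumes the encoder returns the 2048-channel map of the last ResNet50 stage. It is an illustration, not a prescribed design.

```python
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


class SegmentationDecoder(nn.Module):
    """Hypothetical decoder: (B, in_channels, h, w) features -> (B, num_classes, H, W) logits."""

    def __init__(self, in_channels: int, num_classes: int, hidden_channels: int = 256):
        super().__init__()
        self.head = nn.Sequential(
            # mix spatial context at the coarse resolution
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(inplace=True),
            # one output channel per class
            nn.Conv2d(hidden_channels, num_classes, kernel_size=1),
        )

    def forward(self, features: torch.Tensor, out_size: Sequence[int]) -> torch.Tensor:
        logits = self.head(features)
        # upsample the coarse logits back to the input image resolution
        return F.interpolate(logits, size=tuple(out_size), mode="bilinear", align_corners=False)
```

Inside `MyModel.forward`, it would be called as `self.decoder(self.encoder(input), out_size=input.shape[-2:])`. Assuming the masks store one class index per pixel, the logits can go to `nn.CrossEntropyLoss` together with `anno.long()`.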


## 3. Config file

Use a config file to store all the important variables for your network, such as the data path, batch size, type of architecture, and decoder details. The config file is usually a YAML file; the config file we use for pre-training is in this repository and can serve as an example.
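
For reference, these are the config entries that the `UAVDataModule` in Section 1 reads, written as the nested dictionary that `yaml.safe_load` returns for the corresponding YAML file. The values are placeholders, not recommended settings, and `get_cfg` is a hypothetical helper that reports a missing entry by its full name instead of a bare `KeyError`.

```python
from typing import Any, Dict

# Placeholder values; the structure mirrors the keys read by UAVDataModule
example_cfg: Dict[str, Any] = {
    "data": {
        "path_to_dataset": "path/to/dataset",
        "split": "path/to/split.yaml",  # YAML file with a 'train' list of filenames
        "shuffle_train": True,
        "batch_size_train": 8,
        "num_workers": 4,
    },
}


def get_cfg(cfg: Dict[str, Any], dotted_key: str) -> Any:
    """Look up a nested entry such as 'data.batch_size_train'."""
    node = cfg
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            raise KeyError(f"missing config entry: {dotted_key}")
        node = node[part]
    return node
```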

## 4. Main

Implement the `main.py` file.

Remember to load the pre-trained weights with `model.load_state_dict(checkpoint['state_dict'])`.
If you implement the model exactly as our pre-training backbone, this should work out of the box. If it does not, the most common reason is that the keys of the two state dictionaries do not match, i.e. layers and parts of the architecture have different names. One way to solve this is to rename the keys of the loaded dictionary to match yours. For example:

```python
real_checkpoint = {}
for k in checkpoint['state_dict'].keys():
    new_string = k.replace('model.', '').replace('net.', '').replace('resnet50', 'net')
    real_checkpoint[new_string] = checkpoint['state_dict'][k]

# load the renamed weights into your model
model.load_state_dict(real_checkpoint)
```

This turns a key such as `model.net.resnet50.layer` into `net.layer`.
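
Note that `str.replace` rewrites every occurrence of a substring, not only a prefix, so chained replacements can also alter the middle of a key (for instance, `'net.'` also matches the end of a module named `resnet.`). A slightly more defensive variant only rewrites leading prefixes and refuses to merge two keys into one. The function name `rename_prefixes` and the prefix mapping in the usage line are illustrative assumptions; adapt them to the keys actually present in your checkpoint.

```python
from typing import Any, Dict, Sequence, Tuple


def rename_prefixes(state_dict: Dict[str, Any],
                    prefix_map: Sequence[Tuple[str, str]]) -> Dict[str, Any]:
    """Return a copy of state_dict whose keys have their leading prefix rewritten.

    prefix_map is a list of (old_prefix, new_prefix) pairs; the first matching
    pair is applied and keys that match no pair are kept unchanged.
    """
    renamed: Dict[str, Any] = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in prefix_map:
            if key.startswith(old):
                new_key = new + key[len(old):]
                break  # apply at most one rule per key
        if new_key in renamed:
            raise KeyError(f"two checkpoint keys map to {new_key!r}")
        renamed[new_key] = value
    return renamed
```

For the key in the example above, `rename_prefixes(checkpoint['state_dict'], [('model.net.resnet50.', 'net.')])` produces `net.layer` from `model.net.resnet50.layer`; other parts of the checkpoint need their own pairs in the list.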

Once you have loaded the pre-trained weights, you can define everything else as usual and fit or test your model.
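
Before fitting, it can be worth confirming which weights were actually loaded. `torch.nn.Module.load_state_dict` accepts `strict=False`, in which case it loads every matching key and returns the names of the missing and unexpected keys instead of raising (shape mismatches still raise an error). The toy model and checkpoint dictionary below are made up for illustration; in a real run, missing encoder keys usually mean the key renaming is incomplete, while missing decoder keys are expected because the decoder starts from scratch.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Stand-in for MyModel: a pre-trained part ('net') plus a new head."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)      # plays the role of the pre-trained encoder
        self.decoder = nn.Linear(2, 1)  # new, task-specific layers


def load_and_report(model: nn.Module, state_dict: dict):
    """Load all matching weights and return (missing_keys, unexpected_keys)."""
    result = model.load_state_dict(state_dict, strict=False)
    return result.missing_keys, result.unexpected_keys


model = ToyModel()
pretrained = {
    "net.weight": torch.ones(2, 4),
    "net.bias": torch.zeros(2),
    "model.head.weight": torch.zeros(1, 2),  # a name that does not exist in ToyModel
}
missing, unexpected = load_and_report(model, pretrained)
print("missing:", missing)        # the decoder parameters, not present in the checkpoint
print("unexpected:", unexpected)  # checkpoint keys the model has no slot for
```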