<a href="https://colab.research.google.com/github/John-Katis/ConversationalAI-Website/blob/main/Domain%20Incremental%20ResNet18%20Vanilla.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CLRS ResNet18 Continual Learning - Class Incremental**


### GDrive and Libraries instantiation

In [None]:
"""
    Mounting google Drive
"""


from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
"""
    Installing required packages for loading the .tif files and the 
    avalanche distribution
"""


!pip install avalanche-lib
!pip install rasterio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
"""
    Importing the necessary libraries used in this notebook
"""


import os
import sys
import glob
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple

import rasterio
from rasterio.plot import reshape_as_image

import numpy as np

import torch
from torch.utils.data import Dataset
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from torchvision.models import resnet18

from avalanche.benchmarks.utils import DatasetFolder
from avalanche.training import Naive
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import (
    forgetting_metrics, 
    accuracy_metrics,
    loss_metrics, 
    timing_metrics, 
    cpu_usage_metrics, 
    StreamConfusionMatrix,
    disk_usage_metrics, 
    gpu_usage_metrics
)
from avalanche.logging import (
    InteractiveLogger, 
    TextLogger, 
    TensorboardLogger
)

### **Loading the data - Scenario Creation**

* Defining lists and dictionaries of the classes

* Defining the transformations to be used

* Creating the Continual Learning scenarios

In [None]:
"""
    Defining lists and dictionaries for the classes.

    + IDX_CLASS_LABELS: 
                Dictionary that maps integer class label keys to the 
                adjacent class label values

    + CLASSES: 
                A list containing all the class names

    + CLASS_IDX_LABELS: 
                Dictionary that maps string class label keys to the
                adjacent integer class label values
"""


IDX_CLASS_LABELS = {
    0: 'airport',
    1: 'bare-land',
    2: 'beach',
    3: 'bridge',
    4: 'commercial',
    5: 'desert',
    6: 'farmland',
    7: 'forest',
    8: 'golf-course',
    9: 'highway',
    10: 'industrial',
    11: 'meadow',
    12: 'mountain',
    13: 'overpass',
    14: 'park',
    15: 'parking',
    16: 'playground',
    17: 'port',
    18: 'railway',
    19: 'railway-station',
    20: 'residential',
    21: 'river',
    22: 'runway',
    23: 'stadium',
    24: 'storage-tank'
}

CLASSES = [
    'airport', 
    'bare-land', 
    'beach', 
    'bridge', 
    'commercial', 
    'desert',
    'farmland', 
    'forest',
    'golf-course',
    'highway',
    'industrial',
    'meadow',
    'mountain',
    'overpass',
    'park',
    'parking',
    'playground',
    'port',
    'railway',
    'railway-station',
    'residential',
    'river',
    'runway',
    'stadium',
    'storage-tank'
]

CLASS_IDX_LABELS = dict()
for key, val in IDX_CLASS_LABELS.items():
  CLASS_IDX_LABELS[val] = key

In [None]:
"""
    Callable functions that map integer to string class label values and vise
    versa.

    + encode_label(label):
            - Input: a string value class label
            - Output: an integer value class label
    
    + decode_target(target):
            - Input: an integer value class label
            - Output: a string value class label
"""


## Take a class label and return the index
def encode_label(label):
    return CLASS_IDX_LABELS[label] 
     

## Map an integer class index back to its class name
def decode_target(target):
    """Return the class name stored under the integer index ``target``.

    The lookup goes through IDX_CLASS_LABELS; an index missing from that
    mapping raises KeyError.
    """
    class_name = IDX_CLASS_LABELS[target]
    return class_name

In [None]:
"""
    Defining the transformations for the images in a callable function.
    This will then be passed to avalanche's "filelist_benchmark" scenario
    generator to transform the images during training.
"""


torch.manual_seed(10)

data_transformation = transforms.Compose([
    transforms.Resize(64),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

In [None]:
# FOR CLASS INCREMENTAL - use DatasetFolder or dataset_benchmark
# DatasetFolder refers to the PyTorch implementation of the Dataset class.
# Overwrite this with the specifications below (paths and targets are needed - 
# avalanche then resolves the rest)
# Then, once Datasets are created for the classes, pass them to dataset_benchmark
# as a list of datasets

In [None]:
# create a Dataset class with a paths attribute and a targets attribute (lists with equal size)

In [None]:
# dataset_benchmark can be used as well for Domain Incremental

In [None]:
"""
    Dataset class for defining the Domain Incremental scenarios.

    + Attributes:
      - root: a string that defines the root path to the dataset and scenarios
      - paths: a string the is the file name of the text file containing the
               samples for each scenario
      - transform: OPTIONAL a callable function that transforms the images
      - target_transform: OPTIONAL a callable function that transforms the 
                          image labels
    
    + Properties:
      - targets: returns a list of target labels for the dataset
      - data: returns a list of all the images in the dataset
        -- The mapping from targets to data is made with an index. Each image
           and the corresponding label exist in the same index in both lists

    + Methods:
      - _load_data(): PRIVATE is called in the constructor method and loads
                      all the data from the 'paths' file in the local 
                      runtime environment
      - __getitem__(index: int): returns a tuple of (image, label) given an
                                 index from the dataloader
"""


def DC_CLRS_DATASET(Dataset):


  @property
  def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets


  @property
  def test_labels(self):
      warnings.warn("test_labels has been renamed targets")
      return self.targets


  @property
  def train_data(self):
      warnings.warn("train_data has been renamed data")
      return self.data


  @property
  def test_data(self):
      warnings.warn("test_data has been renamed data")
      return self.data


  def __init__(
      self, 
      root: str,
      paths: str,
      transform: Optional[Callable] = None,
      target_transform: Optional[Callable] = None
  ):
    super.__init__(
        root, 
        transform=transform, 
        target_transform=target_transform
    )
    self.paths = paths

    self.data, self.targets = self._load_data()


  def _load_data(self):
    data = list
    targets = list

    with open(os.path.join(self.root, self.paths), 'r') as f:
      ds_paths = f.readlines()

    for path in ds_paths:
      targets.append(path.split('\t')[1].replace('\n', ''))

      image_path = path.split('\t')[0]
      with rasterio.open(os.path.join(self.root, image_path), 'r') as img:
        data[targets].append(img.read([1,2,3]))

    return data, targets

    
  def __getitem__(self, index: int):
    
    image, label = self.data[index], int(self.targets[index])

    # do the transforms

    return image, label