# Implementation Guidelines of Sample Code (Pytorch)

    See the notes in the markdown block that precedes each code block, as well as the # TODO annotations. :D

# Usage guideline of Jupyter Notebook (If needed)

    Installation   : https://jupyter.org/install  
    User Document  : https://jupyter-notebook.readthedocs.io/en/latest/user-documentation.html

# Test Environment (Recommended)

    At test time, we will evaluate your submitted code with the library versions listed below.  
    It is therefore highly recommended to use exactly these package versions.

    test environment : pytorch

### Packages
    python   : 3.8.17  
    torch    : 2.0.1   
    skimage  : 0.21.0  
    cv2      : 4.8.0

# Import libraries (Do not change!)

In [1]:
import os
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import cv2
from torch.utils.data import DataLoader
from skimage import io
import pandas as pd
import matplotlib.pyplot as plt
import math
import copy
import time
import PIL
import pickle



  from .autonotebook import tqdm as notebook_tqdm


# Split dataset (Do not change!)

### Notice 1
    This function splits your dataset of 1000 classes into 10 groups of 100 classes each.  
    It only needs to be run once, before the first training run, to prepare the dataset for continual learning.  
    *Again, you do not need to call this function on every training run once the dataset has been split into 10 groups.

    Notice the annotation codes below. (You can see this codes in 'main' block.)

```python
        parser = argparse.ArgumentParser()   
        # Change this to 'False' after dividing your dataset into 10 groups.
        parser.add_argument('--div_data',   default = True)  
        args = parser.parse_args(args=[])  
```

### Notice 2
    All input images are resized to a fixed size of 128x128.  
    Use this size until further notice.

```python
        # Split input data.  
        for i in range(start, end):
            for img_idx in range(0, 130):
                path = os.path.join(dir, str(i))
                path = path + '/' +str(img_idx)+'.png'
                img = io.imread(path)
                img = cv2.resize(img, (128, 128))  # resize image into 128 x 128 
                x_train.append(img)
```


In [2]:
def train_split(validation_num):
    """Split the 1000-class image folders into 10 groups and cache them as .npy.

    Class folders are processed in groups of 100 consecutive indices
    (1-100, 101-200, ..., 901-1000). Inside every class folder the images
    ``0.png`` .. ``149.png`` are read; the first ``150 - validation_num``
    go to the training split and the last ``validation_num`` go to the
    validation split. Every image is resized to 128 x 128 and labelled with
    its 1-based folder index (stored as a length-1 array, so the label
    arrays have shape ``(N, 1)``).

    For group ``k`` (1-based) the following files are written:
        train_data/x_data_k.npy, train_data/y_data_k.npy
        valid_data/x_data_k.npy, valid_data/y_data_k.npy

    Args:
        validation_num: number of images per class (out of 150) held out
            for validation.

    Raises:
        ValueError: if ``validation_num`` is not within ``0 .. 150``.
    """
    images_per_class = 150
    if not 0 <= validation_num <= images_per_class:
        raise ValueError(
            f"validation_num must be between 0 and {images_per_class}, "
            f"got {validation_num}"
        )
    train_count = images_per_class - validation_num

    # TODO : set dataset path
    # TODO : We recommend placing your code and training dataset in the same location.
    data_dir = './Koh_Young_AI_data/'

    # TODO : Define train data and valid data directory path.
    # TODO : It is recommended not to change these directory paths
    #        (load_train_data / load_valid_data read from the same locations).
    train_save_dir = 'train_data'
    valid_save_dir = 'valid_data'
    os.makedirs(train_save_dir, exist_ok=True)
    os.makedirs(valid_save_dir, exist_ok=True)

    for div_idx in range(0, 10):  # Div into 10 groups
        x_train, y_train = [], []
        x_valid, y_valid = [], []
        start = 100 * div_idx + 1
        end = 100 * div_idx + 101

        for class_idx in range(start, end):
            class_dir = os.path.join(data_dir, str(class_idx))
            label = np.array([class_idx])
            for img_idx in range(images_per_class):
                img = io.imread(os.path.join(class_dir, f'{img_idx}.png'))
                img = cv2.resize(img, (128, 128))
                # Leading images train the model, trailing ones validate it.
                if img_idx < train_count:
                    x_train.append(img)
                    y_train.append(label)
                else:
                    x_valid.append(img)
                    y_valid.append(label)

        # Convert list to numpy
        x_train = np.array(x_train)
        y_train = np.array(y_train)
        x_valid = np.array(x_valid)
        y_valid = np.array(y_valid)

        # Save train/valid data (np.save appends the '.npy' suffix).
        group = div_idx + 1
        np.save(os.path.join(train_save_dir, f'x_data_{group}'), x_train)
        np.save(os.path.join(train_save_dir, f'y_data_{group}'), y_train)
        np.save(os.path.join(valid_save_dir, f'x_data_{group}'), x_valid)
        np.save(os.path.join(valid_save_dir, f'y_data_{group}'), y_valid)

        print(f" ===================== Done in {div_idx} ===================== ")

# Define Dataloader (Do not change!)

    You can define your own data-loading logic with the torch.utils.data.Dataset API.  
    This usually helps reduce the computational burden when dealing with high-dimensional data such as images.  

    reference url : https://pytorch.org/tutorials/beginner/basics/data_tutorial.html


In [3]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over in-memory image and label collections.

    Args:
        x_data: indexable collection of samples; ``x_data[idx]`` is turned
            into a float tensor.
        y_data: indexable collection of labels; ``y_data[idx]`` is turned
            into a long tensor.
        device: stored on the instance as ``self.device``; the dataset
            itself never moves tensors to it.
    """

    def __init__(self, x_data, y_data, device):
        self.x_data, self.y_data, self.device = x_data, y_data, device

    def __len__(self):
        # The number of samples is taken from the inputs only.
        return len(self.x_data)

    def __getitem__(self, idx):
        sample = torch.FloatTensor(self.x_data[idx])
        label = torch.LongTensor(self.y_data[idx])
        # transpose(0, 2) swaps the first and last axes, i.e. axis order
        # (0, 1, 2) becomes (2, 1, 0): a channel-last image becomes
        # channel-first with its two spatial axes exchanged.
        # squeeze(0) drops the leading axis when it has size 1, so a
        # length-1 label array becomes a scalar tensor.
        return sample.transpose(0, 2), label.squeeze(0)

def load_train_data(class_num):
    """Load the cached training arrays of one class group.

    Args:
        class_num: zero-based group index; the files that are read carry
            the number ``class_num + 1``.

    Returns:
        Tuple ``(x_data, y_data)`` loaded from
        ``./train_data/x_data_<n>.npy`` and ``./train_data/y_data_<n>.npy``.
    """
    # TODO : set 'class_path' with your train_data path.
    class_path = './train_data/'
    group_tag = str(class_num + 1)
    x_data = np.load(f'{class_path}x_data_{group_tag}.npy', allow_pickle=True)
    y_data = np.load(f'{class_path}y_data_{group_tag}.npy', allow_pickle=True)
    return x_data, y_data

def load_valid_data(class_num):
    """Load the cached validation arrays of one class group.

    Args:
        class_num: zero-based group index; the files that are read carry
            the number ``class_num + 1``.

    Returns:
        Tuple ``(x_data, y_data)`` loaded from
        ``./valid_data/x_data_<n>.npy`` and ``./valid_data/y_data_<n>.npy``.
    """
    # TODO : set 'class_path' with your valid_data path.
    class_path = './valid_data/'
    group_tag = str(class_num + 1)
    x_data = np.load(f'{class_path}x_data_{group_tag}.npy', allow_pickle=True)
    y_data = np.load(f'{class_path}y_data_{group_tag}.npy', allow_pickle=True)
    return x_data, y_data

# Define training function (You can modify this part!)

    Put your model in training mode with 'model.train()'.  

    useful reference : https://wikidocs.net/195118

In [4]:
def train_model(model, x_train, optimizer, num_epochs, train_data_loader, criterion,train_dataset,indexs):
    """Train one incremental task through the EWC wrapper, then register EWC terms.

    Args:
        model: EWC wrapper exposing ``.model``, ``forward_backward_update``
            and ``register_ewc_params`` (e.g. ElasticWeightConsolidation).
            For backward compatibility, an object without
            ``forward_backward_update`` makes the function fall back to the
            module-level ``ewc`` wrapper, as the previous version always did.
        x_train: training inputs; only ``len(x_train)`` is used, to
            normalise the accuracy.
        optimizer: unused; the wrapper owns and steps its own optimizer.
            Kept so existing call sites keep working.
        num_epochs: unused; overridden by the fixed schedule below.
        train_data_loader: iterable yielding ``(x, y)`` batches.
        criterion: unused; the wrapper applies its own criterion.
        train_dataset: dataset handed to ``register_ewc_params`` after
            training (32 samples per batch, 20 batches).
        indexs: zero-based task index. Task 0 trains for 20 epochs, every
            later task for 10.

    Returns:
        Tuple ``(network, accuracy)``: the wrapped network and the training
        accuracy of the last epoch, computed as ``correct / len(x_train)``.
    """
    # Use the wrapper that was passed in instead of silently relying on a
    # module-level global.
    trainer = model if hasattr(model, 'forward_backward_update') else ewc

    # Fixed schedule: the first task gets a longer warm-up than later ones.
    num_epochs = 20 if indexs == 0 else 10

    network = trainer.model
    # Feed batches to wherever the network's weights live.
    run_device = next(network.parameters()).device
    network.train()                       # Set train mode.

    acc = 0
    for epoch in range(num_epochs):
        acc = 0           # Correct predictions in this epoch
        total_cost = 0.0  # Sum of batch losses in this epoch
        for x, y in train_data_loader:
            y = y.to(run_device)
            loss, out = trainer.forward_backward_update(x.to(run_device), y)
            _, preds = torch.max(out, 1)  # preds : Predicted class

            # .item() yields a plain float, so the running total does not
            # keep every batch's autograd graph alive until the epoch ends.
            total_cost += loss.item()
            acc += torch.sum(preds.detach().cpu() == y.detach().cpu())
        # NOTE: the value labelled "AvgCost" is the summed batch loss.
        print(f" # - EPOCHS {epoch + 1} / {num_epochs} | AvgCost {total_cost} | Accuracy : {acc/len(x_train)} - #")

    # Snapshot parameter means and Fisher estimates for the next task.
    trainer.register_ewc_params(train_dataset, 32, 20)

    # Return trained model and accuracy.
    return trainer.model, acc/len(x_train)

# Define validation function (Do not change!)

    Put your model in evaluation mode with 'model.eval()' or 'model.train(False)'.

In [5]:
def validation(model, x_valid, valid_data_loader, criterion):
    """Evaluate ``model`` on a validation loader and report its accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode.
        x_valid: validation inputs; only ``len(x_valid)`` is used, to
            normalise the accuracy.
        valid_data_loader: iterable yielding ``(x, y)`` batches.
        criterion: loss function; the loss of the last batch is printed.

    Returns:
        Accuracy computed as ``correct / len(x_valid)``.
    """
    model.eval() # Set eval mode

    # Feed batches to wherever the model's weights live.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    acc = 0
    cost = float('nan')  # stays NaN when the loader yields no batch

    # Evaluation needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        for x, y in valid_data_loader:
            y = y.to(run_device)
            out = model(x.to(run_device))
            _, preds = torch.max(out, 1)
            cost = criterion(out, y)
            acc += torch.sum(preds.cpu() == y.cpu())
    print(f" # - ValidCost {cost} | Accuracy : {acc / len(x_valid)} - #")

    # Return Accuracy
    return acc/len(x_valid)

# Define your model and hyperparameter (You can modify this part!)

    This is the pivotal part of the competition.
    We give a simple CNN model as an example.
    Go make your own model!         

In [6]:
class ElasticWeightConsolidation:
    """Elastic Weight Consolidation (EWC) training wrapper around a network.

    After a task has been learned, ``register_ewc_params`` stores, as buffers
    on the wrapped model, a snapshot of every parameter
    (``<name>_estimated_mean``) and a diagonal Fisher estimate
    (``<name>_estimated_fisher``). Dots in parameter names are replaced by
    ``__`` because buffer names may not contain dots. Subsequent training
    steps add ``weight * sum(fisher * (param - mean) ** 2)`` to the task loss.

    Args:
        model: the network to train.
        crit: task loss called as ``crit(output, target)``.
        lr: learning rate of the internally created Adam optimizer.
        weight: multiplier of the consolidation penalty.
    """

    def __init__(self, model, crit, lr=0.001, weight=40000):
        self.model = model
        self.weight = weight
        self.crit = crit
        self.lr = lr  # remembered so load() can rebuild the optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr)

    def _device(self):
        """Return the device the wrapped model's parameters live on."""
        return next(self.model.parameters()).device

    def _update_mean_params(self):
        """Snapshot the current parameters as ``<name>_estimated_mean`` buffers."""
        for param_name, param in self.model.named_parameters():
            _buff_param_name = param_name.replace('.', '__')
            self.model.register_buffer(_buff_param_name+'_estimated_mean', param.detach().clone())

    def _update_fisher_params(self, current_ds, batch_size, num_batch):
        """Store a diagonal Fisher estimate as ``<name>_estimated_fisher`` buffers.

        At most ``num_batch`` shuffled batches of ``current_ds`` are used. The
        estimate is the squared gradient of the mean log-likelihood of the
        true labels over those samples (not the mean of per-sample squared
        gradients).

        Raises:
            ValueError: if no batch could be drawn.
        """
        dl = DataLoader(current_ds, batch_size, shuffle=True)
        run_device = self._device()
        log_likelihoods = []
        for i, (inputs, target) in enumerate(dl):
            if i >= num_batch:  # use exactly num_batch batches
                break
            output = torch.nn.functional.log_softmax(self.model(inputs.to(run_device)), dim=1)
            # Pick each sample's own target log-probability. Indexing with
            # output[:, target] would instead return a (batch, batch) matrix
            # that mixes every sample with every other sample's label.
            picked = output.gather(1, target.to(run_device).view(-1, 1))
            log_likelihoods.append(picked.squeeze(1))
        if not log_likelihoods:
            raise ValueError("cannot estimate Fisher information from zero batches")
        log_likelihood = torch.cat(log_likelihoods).mean()
        grad_log_likelihood = torch.autograd.grad(log_likelihood, self.model.parameters())
        _buff_param_names = [name.replace('.', '__') for name, _ in self.model.named_parameters()]
        for _buff_param_name, grad in zip(_buff_param_names, grad_log_likelihood):
            self.model.register_buffer(_buff_param_name+'_estimated_fisher', grad.detach().clone() ** 2)

    def register_ewc_params(self, dataset, batch_size, num_batches):
        """Record Fisher estimates and parameter means for the task just learned."""
        self._update_fisher_params(dataset, batch_size, num_batches)
        self._update_mean_params()

    def _compute_consolidation_loss(self, weight):
        """Return the EWC penalty, or 0 when no EWC buffers are registered yet."""
        losses = []
        for param_name, param in self.model.named_parameters():
            _buff_param_name = param_name.replace('.', '__')
            estimated_mean = getattr(self.model, '{}_estimated_mean'.format(_buff_param_name), None)
            estimated_fisher = getattr(self.model, '{}_estimated_fisher'.format(_buff_param_name), None)
            # Before the first task is registered there is nothing to
            # consolidate; checking explicitly avoids a broad try/except
            # that could hide unrelated AttributeErrors.
            if estimated_mean is None or estimated_fisher is None:
                return 0
            losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
        return weight * sum(losses)

    def forward_backward_update(self, inputs, target):
        """Run one optimisation step on ``task loss + EWC penalty``.

        Returns:
            Tuple ``(loss, output)``: the total loss tensor and the model
            output for ``inputs``.
        """
        output = self.model(inputs)
        loss = self._compute_consolidation_loss(self.weight) + self.crit(output, target.to(output.device))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss,output

    def save(self, filename):
        """Serialise the whole wrapped model (including EWC buffers) to ``filename``."""
        torch.save(self.model, filename)

    def load(self, filename):
        """Replace the wrapped model with the one stored in ``filename``."""
        # NOTE(security): loading a whole pickled model executes code embedded
        # in the file; only load checkpoints from trusted sources.
        self.model = torch.load(filename)
        # The previous optimizer still references the old model's parameters;
        # rebuild it so subsequent steps update the loaded model.
        self.optimizer = optim.Adam(self.model.parameters(), self.lr)

In [7]:
class BottleNeck(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip path.

    The block outputs ``out_channels * expansion`` channels. The 3x3
    convolution carries the ``stride``. When the stride is not 1 or the
    input channel count differs from the output channel count, the skip
    path is a strided 1x1 convolution followed by batch norm; otherwise it
    is an empty ``nn.Sequential`` (identity).
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        wide_channels = out_channels * BottleNeck.expansion

        main_path = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, wide_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(wide_channels),
        ]
        self.residual_function = nn.Sequential(*main_path)

        # The skip path only needs a projection when the shapes of the two
        # branches would otherwise not match.
        needs_projection = stride != 1 or in_channels != wide_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, wide_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(wide_channels),
            )
        else:
            self.shortcut = nn.Sequential()

        self.relu = nn.ReLU()

    def forward(self, x):
        merged = self.residual_function(x) + self.shortcut(x)
        return self.relu(merged)

class ResNet(nn.Module):
    """ResNet-style classifier: a 32-channel stem followed by four block stages.

    Args:
        block: block class called as ``block(in_channels, out_channels,
            stride)`` and exposing an ``expansion`` attribute.
        num_block: sequence with the number of blocks of each of the four
            stages.
        num_classes: size of the final linear layer's output.
        init_weights: when true, ``_initialize_weights`` is applied after
            construction.
    """

    def __init__(self, block, num_block, num_classes=1000, init_weights=True):
        super().__init__()

        # Running channel count; read and advanced by _make_layer.
        self.in_channels = 32

        stem = [
            nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.conv1 = nn.Sequential(*stem)

        # The first stage keeps the resolution; the other three start with
        # a stride-2 block.
        self.conv2_x = self._make_layer(block, 32, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 64, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 128, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 256, num_block[3], 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256 * block.expansion, num_classes)

        if init_weights:
            self._initialize_weights()

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``.

        A stage always contains at least one block, even when
        ``num_blocks`` is zero or negative.
        """
        block_strides = [stride]
        block_strides.extend(1 for _ in range(num_blocks - 1))

        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_channels, out_channels, block_stride))
            # The next block consumes what this one produced.
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        features = self.conv1(x)
        for stage in (self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x):
            features = stage(features)
        pooled = self.avg_pool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

    def _initialize_weights(self):
        """Initialise conv (Kaiming normal), batch-norm (1 / 0) and linear (N(0, 0.01)) layers."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)


def resnet50():
    """Build the bottleneck ResNet used in this notebook.

    The four stages contain 3, 8, 6 and 6 BottleNeck blocks; all other
    ResNet arguments keep their defaults.
    """
    blocks_per_stage = [3, 8, 6, 6]
    return ResNet(BottleNeck, blocks_per_stage)


# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Module-level network instance shared by the following cells.
model = resnet50()
model = model.to(device)

In [8]:
# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# TODO : Define your model and hyperparameters.
# NOTE: `optimizer`, `learning_rate` and `num_epochs` are passed to
# train_model(), but the training steps are performed by the `ewc` wrapper
# created below (its own Adam optimizer, lr=5e-5), and train_model()
# overrides the epoch count (20 for the first task, 10 afterwards).
batch_size        = 16
learning_rate     = 0.001
num_epochs        = 15
optimizer         = optim.Adam(model.parameters(), lr=learning_rate)
random_seed       = 555
validation_num    = 20 # Number of images per class (out of 150) held out for validation.
criterion = nn.CrossEntropyLoss() # Task loss, shared by the EWC wrapper and validation().

# EWC wrapper around the shared model; it is what train_model() is called with.
ewc = ElasticWeightConsolidation(model, crit=criterion, lr=5e-5)

# Incremental Learning. (Do not change!)

### WARNING:
    The training and validation datasets MUST each be prepared properly beforehand.  
    Otherwise, your submitted code will be rejected immediately.

In [None]:

"""  
--div_data  : split your data or not.   
"""
# An empty argument list is parsed, so the default below is always used.
parser = argparse.ArgumentParser()  
parser.add_argument('--div_data',   default = False)  # Set to True once to build the train/valid .npy splits; keep False afterwards.
args = parser.parse_args(args=[])  

# TODO : Result directory required by the template. Don't change this path.
# NOTE: this cell only creates it; the checkpoints below go to './model_save/'.
save_dir = './result'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# TODO : Seed (Python's `random` module and PyTorch only).
random.seed(random_seed)
torch.manual_seed(random_seed)

# TODO : Split dataset according to argument '--div_data'
if args.div_data == True:
    train_split(validation_num)
else:
    pass


""" 
1. training      : train each 100 classes sequentailly with respect to 1000 output class. 
    trainining class === 1-100 -> 101-200 -> 201-300 -> 301-400 -> ... -> 901-1000
    
2. validation    : validate each trained model.
    validation class === 1-100 -> 1-200 -> 1-300 -> ... -> 1-1000
    
3. model save    : saves each trained model.                
"""

# One iteration per group of 100 classes.
for div_idx in range(10):

    # TODO : Load your train and validation data
    x_train, y_train = load_train_data(div_idx)
    x_valid, y_valid = load_valid_data(div_idx)

    """
        in case of tranining 1  -100 classes, validate on 1-100 classes
        in case of tranining 101-200 classes, validate on 1-200 classes
        in case of tranining 201-300 classes, validate on 1-300 classes
        and so on...            
    """
    
    # x_val_tmp / y_val_tmp accumulate the validation data of every group
    # seen so far; they keep the original 1-based labels.
    if div_idx == 0:
        x_val_tmp = x_valid
        y_val_tmp = y_valid
    else:
        x_val_tmp = np.concatenate((x_val_tmp, x_valid), axis = 0)
        y_val_tmp = np.concatenate((y_val_tmp, y_valid), axis = 0)
        x_valid   = x_val_tmp
        y_valid   = y_val_tmp

    # TODO : Shift labels to start from 0 so they match the model's output index (stored labels start from 1).
    # The subtraction creates new arrays, so the accumulated y_val_tmp is left untouched.
    y_train = y_train - 1
    y_valid = y_valid - 1

    # TODO : Define dataset and dataloader
    train_dataset     = CustomDataset(x_train, y_train, device)
    valid_dataset     = CustomDataset(x_valid, y_valid, device)
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_data_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # TODO : train and validate
    # The EWC wrapper (not the bare network) is passed as the first argument.
    trained_model, acc_train = train_model(ewc, x_train, optimizer, num_epochs, train_data_loader, criterion,train_dataset,div_idx)
    acc_valid                = validation(trained_model, x_valid, valid_data_loader, criterion)

    # A checkpoint is written after every group, not only after the last one.
    MODEL_SAVE_FOLDER_PATH = './model_save/'
    num=str(div_idx)
    if not os.path.exists(MODEL_SAVE_FOLDER_PATH):
        os.mkdir(MODEL_SAVE_FOLDER_PATH)        
    model_path = MODEL_SAVE_FOLDER_PATH + 'continual_model_ewc'+num+'.pt'
    # TODO : save the trained model's state_dict to 'model_path'
    torch.save(trained_model.state_dict(), model_path)

    print(f'{str(div_idx)} Iteration Done.')

 # - EPOCHS 1 / 20 | AvgCost 3151.29833984375 | Accuracy : 0.04685314744710922 - #
 # - EPOCHS 2 / 20 | AvgCost 2304.6611328125 | Accuracy : 0.12525252997875214 - #
 # - EPOCHS 3 / 20 | AvgCost 1961.093505859375 | Accuracy : 0.19316239655017853 - #
 # - EPOCHS 4 / 20 | AvgCost 1701.47802734375 | Accuracy : 0.2777000665664673 - #
 # - EPOCHS 5 / 20 | AvgCost 1443.014404296875 | Accuracy : 0.37668997049331665 - #
 # - EPOCHS 6 / 20 | AvgCost 1223.1766357421875 | Accuracy : 0.4668997526168823 - #
 # - EPOCHS 7 / 20 | AvgCost 1005.3316650390625 | Accuracy : 0.556332528591156 - #
 # - EPOCHS 8 / 20 | AvgCost 790.9276123046875 | Accuracy : 0.6406371593475342 - #
 # - EPOCHS 9 / 20 | AvgCost 603.7872924804688 | Accuracy : 0.7193472981452942 - #
 # - EPOCHS 10 / 20 | AvgCost 463.1643981933594 | Accuracy : 0.7867909669876099 - #
 # - EPOCHS 11 / 20 | AvgCost 363.6328430175781 | Accuracy : 0.8275835514068604 - #
 # - EPOCHS 12 / 20 | AvgCost 282.5030822753906 | Accuracy : 0.8677544593811035 - #
