# Implementation Guidelines of Sample Code (Pytorch)

    See the annotations at every markdown blocks correspoding to each code blocks, and also # TODO annotations. :D

# Usage guideline of Jupyter Notebook (If needed)

    Installation   : https://jupyter.org/install  
    User Document  : https://jupyter-notebook.readthedocs.io/en/latest/user-documentation.html

# Test Environment (Recommended)

    In test time, we will evaluate the given codes from you with the following version of libraries.  
    So, it is highly recommended to use those packages with specific version below.

    test environment : pytorch

### Packages
    python   : 3.8.17  
    torch    : 2.0.1   
    skimage  : 0.21.0  
    cv2      : 4.8.0

# Import libraries (Do not change!)

In [1]:
import os
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import cv2
from torch.utils.data import DataLoader
from skimage import io
import pandas as pd
import matplotlib.pyplot as plt
import math
import copy
import time
import PIL
import pickle



  from .autonotebook import tqdm as notebook_tqdm


# Split dataset (Do not change!)

### Notice 1
    This function do split your dataset of 1000 classes into 10 groups of 100 each.    
    So, it is needed to be implemented just once at first to split your dataset for continual learning.   
    *Again, you dont need to use this function in every tranining time if you already split your dataset into 10 groups.

    Notice the annotation codes below. (You can see this codes in 'main' block.)

```python
        parser = argparse.ArgumentParser()   
        # Change this as 'False' after dividing your datsaet into 10 groups.
        parser.add_argument('--div_data',   default = True)  
        args = parser.parse_args(args=[])  
```

### Notice 2
    We reshapes all the input data size into constant 128x128.   
    Until further notification, use this constant size. 

```python
        # Split input data.  
        for i in range(start, end):
            for img_idx in range(0, 130):
                path = os.path.join(dir, str(i))
                path = path + '/' +str(img_idx)+'.png'
                img = io.imread(path)
                img = cv2.resize(img, (128, 128))  # resize image into 128 x 128 
                x_train.append(img)
```


In [2]:
def train_split(validation_num):
    # TODO : set dataset path
    # TODO : We recommends you to place your code and tranining dataset in the same location.
    
    dir = './Koh_Young_AI_data/'
    

    for div_idx in range(0, 10): # Div into 10 groups
        # Divide data 0-129 for training, 130-150 for validation.
        x_train = []
        x_valid = []
        y_train = []
        y_valid = []
        start   = 100*div_idx + 1
        end     = 100*div_idx + 101

        # Split input data.  
        for i in range(start, end):
            for img_idx in range(0, 150-validation_num):
                path = os.path.join(dir, str(i))
                path = path + '/' +str(img_idx)+'.png'
                img = io.imread(path)
                img = cv2.resize(img, (128, 128))
                x_train.append(img)

            for img_idx in range(150-validation_num, 150):
                path = os.path.join(dir, str(i))
                path = path + '/' +str(img_idx)+'.png'
                img = io.imread(path)
                img = cv2.resize(img, (128, 128))
                x_valid.append(img)

        # Split corresponding output label data.
        for folder_idx in range(start, end):
            for img_idx in range(0, 150-validation_num):
                y_train.append(np.array([folder_idx]))
            for img_idx in range(150-validation_num, 150):
                y_valid.append(np.array([folder_idx]))

        # Convert list to numpy 
        x_train = np.array(x_train)
        y_train = np.array(y_train)
        x_valid = np.array(x_valid)
        y_valid = np.array(y_valid)

        # TODO : Define train data and valid data directory path.
        # TODO : Recommends not to change these directory paths. 
        train_save_dir = 'train_data'
        valid_save_dir = 'valid_data'
        if not os.path.exists(train_save_dir):
            os.makedirs(train_save_dir)

        if not os.path.exists(valid_save_dir):
            os.makedirs(valid_save_dir)

        # TODO : Save train/valid data
        np.save(f'./train_data/x_data_{div_idx+1}', x_train)
        np.save(f'./train_data/y_data_{div_idx+1}', y_train)
        np.save(f'./valid_data/x_data_{div_idx+1}', x_valid)
        np.save(f'./valid_data/y_data_{div_idx+1}', y_valid)

        print(f" ===================== Done in {div_idx} ===================== ")

# Define Dataloader (Do not change!)

    You can define your own dataloader with API of torch.utils.data.Dataset.  
    This can usually help you to reduce computational burden when dealing with high dimensional data, such as images.  

    reference url : https://pytorch.org/tutorials/beginner/basics/data_tutorial.html


In [3]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x_data, y_data, device):
        self.x_data = x_data
        self.y_data = y_data
        self.device = device

    def __getitem__(self, idx):
        # .transpose(0, 2) : width x height x channel (0, 1, 2) ---> channel x width x height (2, 0, 1).
        # .squeeze(0) : add extra dimension at axis 0.
        x = torch.FloatTensor(self.x_data[idx]).transpose(0, 2)
        y = torch.LongTensor(self.y_data[idx]).squeeze(0)
        return x, y
        
    def __len__(self):
        return len(self.x_data)

def load_train_data(class_num):
    # TODO : set 'class_path' with your train_data path.
    class_path  = f'./train_data/'
    x_data_path = class_path + 'x_data_' + str(class_num+1) + '.npy'
    y_data_path = class_path + 'y_data_' + str(class_num+1) + '.npy'
    x_data      = np.load(x_data_path, allow_pickle=True)
    y_data      = np.load(y_data_path, allow_pickle=True)
    return x_data, y_data

def load_valid_data(class_num):
    # TODO : set 'class_path' with your valid_data path.
    class_path  = f'./valid_data/'
    x_data_path = class_path + 'x_data_' + str(class_num+1) + '.npy'
    y_data_path = class_path + 'y_data_' + str(class_num+1) + '.npy'
    x_data      = np.load(x_data_path, allow_pickle=True)
    y_data      = np.load(y_data_path, allow_pickle=True)

    # return processed data. 
    return x_data, y_data

# Define tranining function (You can modify this part!)

    Set your model with train mode as 'model.train()'.   

    useful reference : https://wikidocs.net/195118

In [4]:
def train_model(model, x_train, optimizer, num_epochs, train_data_loader, criterion,train_dataset,indexs):
    """
    model             : your customized model 
    x_train           : input data for tranining
    optimizer         : optimizer
    num_epoches       : number of iteration
    train_data_loader : dataloder of training dataset
    indexs            : train seq
    """
    
    if indexs== 0:
        num_epochs=100
    else:
        num_epochs=50
        
    if indexs>0:
        num_classess=(indexs+1)*100
        criterion = nn.CrossEntropyLoss()
        #pretrained_dict = ewc.model.state_dict()
        pretrained_dict = model.state_dict()
        '''new_ewc = ElasticWeightConsolidation(ViT(
            image_size = 128,
            patch_size = 16,
            num_classes = num_classess,
            dim = 64,
            depth = 6,
            heads = 6,
            mlp_dim = 128,
            dropout = 0.1,
            emb_dropout = 0.1
        ).to(device), crit=criterion, lr=2e-4)'''
        new_model=ViT(
            image_size = 128,
            patch_size = 16,
            num_classes = num_classess,
            dim = 256,
            depth = 6,
            heads = 6,
            mlp_dim = 256,
            dropout = 0.1,
            emb_dropout = 0.1
        ).to(device)
        #new_ewc.register_ewc_params(train_dataset, 16, 100)
        #new_model_dict = new_ewc.model.state_dict()
        new_model_dict = new_model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in new_model_dict}
        pretrained_dict.pop('mlp_head.0.weight')
        pretrained_dict.pop('mlp_head.0.bias')
        pretrained_dict.pop('mlp_head.1.weight')
        pretrained_dict.pop('mlp_head.1.bias')
        #new_ewc.model.update(pretrained_dict)
        new_model.load_state_dict(pretrained_dict,strict=False)
        del model
        #old_model=model
        model=new_model
    #ewc.model.train()                     # Set train mode. 
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        acc      = 0 # Accuracy
        avg_cost = 0 # Average Cost 
        for x, y in train_data_loader:
            #loss=ewc.forward_backward_update(x.to(device), y.to(device))
            #out = ewc.model(x.to(device))            # Inference (batch,classes)
            out = model(x.to(device))
            _, preds = torch.max(out, 1)         # preds : Predicted class
            cost = criterion(out, y.to(device))  # Calculates errors between true label and predictions with respect to your criterion. 

            # Optimize processs.
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

            avg_cost += cost # Average cost 
            acc      += torch.sum(preds.detach().cpu() == (y.data).detach().cpu()) # Accuracy
        print(f" # - EPOCHS {epoch + 1} / {num_epochs} | AvgCost {avg_cost} | Accuracy : {acc/len(x_train)} - #")
    #ewc.register_ewc_params(train_dataset, 16, 100)
    #pretrained_dict = ewc.model.state_dict()
    #print(pretrained_dict)
    #exit(1)
    # Return trainded model and accuracy. 
    return model, acc/len(x_train)

# Define validataion function (Do not change!)

    And eval mode as 'model.eval()' or 'model.train(False)'.

In [5]:
def validation(model, x_valid, valid_data_loader, criterion):
    """
    model             : your customized model 
    x_vallid          : input data for validation
    valid_data_loader : dataloder of valid dataset 
    """
    
    model.eval() # Set eval mode
    
    acc = 0
    
    for x, y in valid_data_loader:
        out = model(x.data.to(device))
        _, preds = torch.max(out, 1)
        cost  = criterion(out, y.to(device))
        acc += torch.sum(preds.detach().cpu() == (y.data).detach().cpu())
    print(f" # - ValidCost {cost} | Accuracy : {acc / len(x_valid)} - #")

    # Return Accuracy 
    return acc/len(x_valid)

# Define your model and hyperparameter (You can modify this part!)

    Here is the pivotal part of your competition.
    We gives a simple CNN model, for example. 
    Go make your own model!         

In [6]:
class ElasticWeightConsolidation:

    def __init__(self, model, crit, lr=0.001, weight=1000000):
        self.model = model
        self.weight = weight
        self.crit = crit
        self.optimizer = optim.Adam(self.model.parameters(), lr)
        #self.make_fisher()
        
    def _update_mean_params(self):
        for param_name, param in self.model.named_parameters():
            _buff_param_name = param_name.replace('.', '__')
            self.model.register_buffer(_buff_param_name+'_estimated_mean', param.data.clone())

    def _update_fisher_params(self, current_ds, batch_size, num_batch):
        dl = DataLoader(current_ds, batch_size, shuffle=True)
        log_liklihoods = []
        for i, (inputs, target) in enumerate(dl):
            if i > num_batch:
                break
            output = torch.nn.functional.log_softmax(self.model(inputs.to(device)), dim=1)
            log_liklihoods.append(output[:, target])
        log_likelihood = torch.cat(log_liklihoods).mean()
        grad_log_liklihood = torch.autograd.grad(log_likelihood, self.model.parameters())
        _buff_param_names = [param[0].replace('.', '__') for param in self.model.named_parameters()]
        for _buff_param_name, param in zip(_buff_param_names, grad_log_liklihood):
            self.model.register_buffer(_buff_param_name+'_estimated_fisher', param.data.clone() ** 2)
    def make_fisher(self,):
        _buff_param_names = [param[0].replace('.', '__') for param in self.model.named_parameters()]
        for _buff_param_name in _buff_param_names:
            self.model.register_buffer(_buff_param_name+'_estimated_fisher', torch.tensor(0))
        for param_name, param in self.model.named_parameters():
            _buff_param_name = param_name.replace('.', '__')
            self.model.register_buffer(_buff_param_name+'_estimated_mean', torch.tensor(0))
            
    def register_ewc_params(self, dataset, batch_size, num_batches):
        self._update_fisher_params(dataset, batch_size, num_batches)
        self._update_mean_params()

    def _compute_consolidation_loss(self, weight):
        try:
            losses = []
            for param_name, param in self.model.named_parameters():
                _buff_param_name = param_name.replace('.', '__')
                estimated_mean = getattr(self.model, '{}_estimated_mean'.format(_buff_param_name))
                estimated_fisher = getattr(self.model, '{}_estimated_fisher'.format(_buff_param_name))
                losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
            return (weight / 2) * sum(losses)
        except AttributeError:
            return 0

    def forward_backward_update(self, inputs, target):
        output = self.model(inputs)
        loss = self._compute_consolidation_loss(self.weight) + self.crit(output, target.to(device))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss
    
    def save(self, filename):
        torch.save(self.model, filename)

    def load(self, filename):
        self.model = torch.load(filename)

In [7]:
from math import sqrt
import torch
import torch.nn.functional as F
from torch import nn


# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class LSA(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads #64*16

        self.heads = heads
        self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) #output이 그럼 1024*3

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        batch_size=x.shape[0]
        patch_size=x.shape[1]
        x = self.norm(x)
        #print(x.shape)
        qkv = self.to_qkv(x).chunk(3, dim = -1) #to_qkv 하면 batch,patch,1024*3이 나옴 patch=257임
        #test=torch.chunk(qkv,3,dim=-1)

        #print(test)
        q, k, v = map(lambda t: torch.reshape(t, (batch_size,self.heads,patch_size,-1)), qkv)
        #q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) #batch, head, patch, 나머지

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()

        mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool)
        mask_value = -torch.finfo(dots.dtype).max
        dots = dots.masked_fill(mask, mask_value)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)

        out = torch.reshape(out, (batch_size,patch_size,-1))
        #out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class SPT(nn.Module):
    def __init__(self, *, dim, patch_size, channels = 3):
        super().__init__()
        self.patch_dim = patch_size * patch_size * 5 * channels

        self.to_patch_tokens = nn.Sequential(
            #Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),#h=256//16
            nn.LayerNorm(self.patch_dim),
            nn.Linear(self.patch_dim, dim)
        )#입력(batch,원래3인데 12개 추가해서 15, 256,256)

    def forward(self, x):
        batch_size=x.shape[0]
        shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
        shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))

        x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
        #print(np.shape(torch.asarray(x_with_shifts).detach().numpy()))
        x_with_shifts=torch.reshape(x_with_shifts,(batch_size,-1,self.patch_dim))
        return self.to_patch_tokens(x_with_shifts)

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        #print(x.shape)
        cls_tokens = self.cls_token.repeat(b,1,1)
        #cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)

        x = torch.cat((cls_tokens, x), dim=1)
        #print(np.shape(torch.asarray(x).detach().numpy()))
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        #print(x.shape)
        #exit(1)
        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [8]:
class SimpleCNN(nn.Module):
    def __init__(self, in_channel, num_class):
        super().__init__()
        self.num_class = num_class
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channel, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc1 = nn.Linear(64*16*16, self.num_class)

    def forward(self, x):
        # rescaling
        x = x/255.0

        # TODO : Convolution layer
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)

        # TODO : Reshape for fully-connected layer
        out = out.view(-1, 64*16*16)

        # TODO : Fully-connected layer
        out = self.fc1(out)

        # TODO : final-pycharm output - 1000 (class num)
        return out

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# TODO : Define your model

model=ViT(
    image_size = 128,
    patch_size = 16,
    num_classes = 100,
    dim = 256,
    depth = 6,
    heads = 6,
    mlp_dim = 256,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)
# TODO : Set your hyperparameters
batch_size        = 16
learning_rate     = 0.0001
num_epochs        = 15
optimizer         = optim.Adam(model.parameters(), lr=learning_rate)
random_seed       = 555
validation_num    = 20 # for 150 images for class, the number for validation data
criterion = nn.CrossEntropyLoss() # Define criterion. 

'''ewc = ElasticWeightConsolidation(ViT(
    image_size = 128,
    patch_size = 16,
    num_classes = 100,
    dim = 64,
    depth = 6,
    heads = 6,
    mlp_dim = 128,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device), crit=criterion, lr=1e-4)'''

'ewc = ElasticWeightConsolidation(ViT(\n    image_size = 128,\n    patch_size = 16,\n    num_classes = 100,\n    dim = 64,\n    depth = 6,\n    heads = 6,\n    mlp_dim = 128,\n    dropout = 0.1,\n    emb_dropout = 0.1\n).to(device), crit=criterion, lr=1e-4)'

# Incremental Learning. (Do not change!)

### WARNING:
    The training and validation datasets each SHOULD BE prepared properly beforehand.  
    If not, the submitted code from you will be immediately rejected.

In [10]:

"""  
--div_data  : split your data or not.   
"""
parser = argparse.ArgumentParser()  
parser.add_argument('--div_data',   default = False)  # Change this with 'False' after dividing your datsaet into 10 groups.
args = parser.parse_args(args=[])  

# TODO : Saving tranined model in this location. Don't change this path. 
save_dir = './result'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# TODO : Seed
random.seed(random_seed)
torch.manual_seed(random_seed)

# TODO : Split dataset according to argument '--div_data'
if args.div_data == True:
    train_split(validation_num)
else:
    pass


""" 
1. training      : train each 100 classes sequentailly with respect to 1000 output class. 
    trainining class === 1-100 -> 101-200 -> 201-300 -> 301-400 -> ... -> 901-1000
    
2. validation    : validate each trained model.
    validation class === 1-100 -> 1-200 -> 1-300 -> ... -> 1-1000
    
3. model save    : saves each trained model.                
"""

for div_idx in range(10):

    # TODO : Load your train and validation data
    x_train, y_train = load_train_data(div_idx)
    x_valid, y_valid = load_valid_data(div_idx)

    """
        in case of tranining 1  -100 classes, validate on 1-100 classes
        in case of tranining 101-200 classes, validate on 1-200 classes
        in case of tranining 201-300 classes, validate on 1-300 classes
        and so on...            
    """
    
    if div_idx == 0:
        x_val_tmp = x_valid
        y_val_tmp = y_valid
    else:
        x_val_tmp = np.concatenate((x_val_tmp, x_valid), axis = 0)
        y_val_tmp = np.concatenate((y_val_tmp, y_valid), axis = 0)
        x_valid   = x_val_tmp
        y_valid   = y_val_tmp

    # TODO : let the label starts from 0 to match the output index of model prediction. (Currently the label starts from 1.)
    y_train = y_train - 1
    y_valid = y_valid - 1

    # TODO : Define dataset and dataloader
    train_dataset     = CustomDataset(x_train, y_train, device)
    valid_dataset     = CustomDataset(x_valid, y_valid, device)
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_data_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # TODO : train and validate
    trained_model, acc_train = train_model(model, x_train, optimizer, num_epochs, train_data_loader, criterion,train_dataset,div_idx)
    acc_valid                = validation(trained_model, x_valid, valid_data_loader, criterion)

    #if div_idx == 9:
    MODEL_SAVE_FOLDER_PATH = './model_save/'
    num=str(div_idx)
    if not os.path.exists(MODEL_SAVE_FOLDER_PATH):
        os.mkdir(MODEL_SAVE_FOLDER_PATH)        
    model_path = MODEL_SAVE_FOLDER_PATH + 'continual_model'+num+'.pt'
    # TODO : save trained model in 'save_model_path'
    torch.save(trained_model.state_dict(), model_path)

    print(f'{str(div_idx)} Iteration Done.')

 # - EPOCHS 1 / 100 | AvgCost 3643.383056640625 | Accuracy : 0.09300699084997177 - #
 # - EPOCHS 2 / 100 | AvgCost 3581.782958984375 | Accuracy : 0.16993007063865662 - #
 # - EPOCHS 3 / 100 | AvgCost 3528.75390625 | Accuracy : 0.24009324610233307 - #
 # - EPOCHS 4 / 100 | AvgCost 3472.34326171875 | Accuracy : 0.31351980566978455 - #
 # - EPOCHS 5 / 100 | AvgCost 3414.3408203125 | Accuracy : 0.3852369785308838 - #
 # - EPOCHS 6 / 100 | AvgCost 3389.180908203125 | Accuracy : 0.41546231508255005 - #
 # - EPOCHS 7 / 100 | AvgCost 3364.094482421875 | Accuracy : 0.44630923867225647 - #
 # - EPOCHS 8 / 100 | AvgCost 3349.8515625 | Accuracy : 0.46309247612953186 - #
 # - EPOCHS 9 / 100 | AvgCost 3326.056396484375 | Accuracy : 0.49494948983192444 - #
 # - EPOCHS 10 / 100 | AvgCost 3297.87353515625 | Accuracy : 0.5305361151695251 - #
 # - EPOCHS 11 / 100 | AvgCost 3244.88916015625 | Accuracy : 0.6010878086090088 - #
 # - EPOCHS 12 / 100 | AvgCost 3202.164306640625 | Accuracy : 0.653224527835846 

 # - EPOCHS 99 / 100 | AvgCost 2927.109375 | Accuracy : 0.9828282594680786 - #
 # - EPOCHS 100 / 100 | AvgCost 2925.700439453125 | Accuracy : 0.9840714931488037 - #
 # - ValidCost 3.6591711044311523 | Accuracy : 0.8868687152862549 - #
0 Iteration Done.
 # - EPOCHS 1 / 50 | AvgCost 4259.7578125 | Accuracy : 0.008080808445811272 - #
 # - EPOCHS 2 / 50 | AvgCost 4259.76171875 | Accuracy : 0.00901320856064558 - #
 # - EPOCHS 3 / 50 | AvgCost 4259.76806640625 | Accuracy : 0.007692307699471712 - #
 # - EPOCHS 4 / 50 | AvgCost 4259.75341796875 | Accuracy : 0.008158507756888866 - #
 # - EPOCHS 5 / 50 | AvgCost 4259.77197265625 | Accuracy : 0.008780108764767647 - #
 # - EPOCHS 6 / 50 | AvgCost 4259.7509765625 | Accuracy : 0.009401709772646427 - #
 # - EPOCHS 7 / 50 | AvgCost 4259.7724609375 | Accuracy : 0.008236207999289036 - #
 # - EPOCHS 8 / 50 | AvgCost 4259.7783203125 | Accuracy : 0.009479409083724022 - #
 # - EPOCHS 9 / 50 | AvgCost 4259.74365234375 | Accuracy : 0.008935509249567986 - #
 #


KeyboardInterrupt



In [None]:
print(model)
model_dict = model.state_dict()
# remove the keys corresponing to the linear layer in the pretrained_dict
#model_dict.pop(mlp_head.0.weight)
#model_dict.pop(mlp_head.0.bias)
#model_dict.pop(mlp_head.1.weight)
#model_dict.pop(mlp_head.1.bias)
# now update the model dict with pretrained dict
print(model_dict)