In [None]:
from os import environ
from os.path import expanduser

# Root of every path below. HOME may be missing from the environment, in
# which case environ.get() returns None and the f-strings would silently
# produce paths starting with "None/"; fall back to expanduser('~') instead.
home_folder = environ.get('HOME') or expanduser('~')

# Project workspace and the dataset location, both directly under the home folder.
WORKSPACE = f'{home_folder}/Skin_Diseases_Detection'
DATASET_FOLDER = f"{home_folder}/skdi_dataset"

# Dataset sub-folders handed to ImageFolder by the Dataset class.
TRAINING_FOLDER = f'{DATASET_FOLDER}/Training'
TESTING_FOLDER = f'{DATASET_FOLDER}/Testing'

# Per-purpose output folders inside the workspace.
CHECKPOINT_FOLDER = f'{WORKSPACE}/checkpoints'
STATUS_FOLDER = f'{WORKSPACE}/status'
PLOT_FOLDER = f'{WORKSPACE}/plots'
PARAM_FOLDER = f'{WORKSPACE}/param'
CONFIG_FOLDER = f'{WORKSPACE}/config'

from paths import TRAINING_FOLDER, TESTING_FOLDER
import torchvision.transforms.v2 as tf
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset, DataLoader
import os

# Number of DataLoader worker processes. os.cpu_count() returns None when the
# CPU count cannot be determined, and DataLoader needs an integer, so fall
# back to 0 (data is then loaded in the main process).
workers = os.cpu_count() or 0

class Dataset:
    """Build the training, validation and testing datasets and their loaders.

    The images under TRAINING_FOLDER are divided into a training and a
    validation subset; the images under TESTING_FOLDER form the test set.

    Attributes set by the constructor:
        test_ds:  ImageFolder over TESTING_FOLDER using ``test_trans``.
        train_ds: Subset of the training ImageFolder used for training.
        val_ds:   Subset of the training ImageFolder used for validation
                  (it shares ``train_trans`` with ``train_ds`` because both
                  are views of the same ImageFolder).
        train_dl, val_dl, test_dl: DataLoaders over the three datasets, all
                  using ``train_batch_size`` and shuffling.
    """
    def __init__(self, train_trans, test_trans, train_batch_size = 32, val_rat = 0.2, split_seed = 0):
        """
        Args:
            train_trans: transform applied to images from TRAINING_FOLDER.
            test_trans: transform applied to images from TESTING_FOLDER.
            train_batch_size: batch size used by all three loaders.
            val_rat: fraction of the training images held out for validation.
            split_seed: seed of the train/validation shuffle; a fixed value
                keeps the split identical between runs.
        """
        # Local import so this block does not depend on the file's import list.
        import random

        self._ds = ImageFolder(TRAINING_FOLDER, train_trans)
        self.test_ds = ImageFolder(TESTING_FOLDER, test_trans)
        ds_size = len(self._ds)
        val_size = int(val_rat*ds_size)
        train_size = ds_size - val_size

        # ImageFolder lists its samples class by class, so taking the last
        # `val_size` indices as-is would put only the last class(es) into the
        # validation set. Shuffle the indices first; a private, seeded RNG
        # keeps the split reproducible and leaves the global RNG untouched.
        indices = list(range(ds_size))
        random.Random(split_seed).shuffle(indices)

        self.train_ds = Subset(self._ds, indices[:train_size])
        self.val_ds = Subset(self._ds, indices[train_size:])
        self.train_dl = DataLoader(self.train_ds, batch_size=train_batch_size, shuffle=True, num_workers=workers)
        self.test_dl = DataLoader(self.test_ds, batch_size=train_batch_size, shuffle=True, num_workers=workers)
        self.val_dl = DataLoader(self.val_ds, batch_size=train_batch_size, shuffle=True, num_workers=workers)

import torch.nn as nn
from torchvision.datasets import ImageFolder


def make_cnn(dataset: ImageFolder, hid_layers = [64, 64],
            act_fn='relu', max_pool = None, pooling_after_layers = 2,
            conv_layers=[[32, 3, 1],
                         [16, 3, 1]]):
    
    img = dataset.__getitem__(0)[0]
    input_shape = img.shape
    num_of_classes = len(dataset.classes)
    layers = []
    activation_fun = {'relu': nn.ReLU(), 'softplus':nn.Softplus(), 'tanh':nn.Tanh(), 'elu': nn.ELU()}

    assert pooling_after_layers < len(conv_layers), 'exceeding the number conv layers..'

    in_chann, inp_h, inp_w = input_shape
    for ind, conv in enumerate(conv_layers):
        out_chann, filter_size, stride = conv
        layers.append(nn.Conv2d(in_chann, out_chann, filter_size, stride))
        layers.append(activation_fun[act_fn])

        out_h = (inp_h - filter_size)//stride + 1
        out_w = (inp_w - filter_size)//stride + 1
        inp_h = out_h
        inp_w = out_w

        if max_pool is not None and ((ind+1) % pooling_after_layers == 0 or ind == (len(conv_layers) - 1)):
            layers.append(nn.MaxPool2d(max_pool[0], max_pool[1]))
            out_h = (inp_h - max_pool[0])//max_pool[1] + 1
            out_w = (inp_w - max_pool[0])//max_pool[1] + 1
            inp_h = out_h
            inp_w = out_w
        in_chann = out_chann

    layers.append(nn.Flatten())
    layers.append(nn.Linear(inp_h*inp_w*in_chann, hid_layers[0]))
    layers.append(activation_fun[act_fn])

    
    if len(hid_layers) > 1:
        dim_pairs = zip(hid_layers[:-1], hid_layers[1:])
        for in_dim, out_dim in list(dim_pairs):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(activation_fun[act_fn])

    layers.append(nn.Linear(hid_layers[-1], num_of_classes))
    return nn.Sequential(*layers)

from torchvision import disable_beta_transforms_warning
disable_beta_transforms_warning()
import torchvision.transforms.v2 as tf
from utils import convert

import torch
import os
from paths import CHECKPOINT_FOLDER
import yaml

def convert(**options):
    """Collect the given keyword arguments into a dictionary and return it."""
    collected = dict(options)
    return collected

class Utils:
    """File, checkpoint and model persistence helpers for a training run.

    Besides the attributes initialised in ``__init__``, the methods read
    ``optim``, ``param``, ``status_file``, ``param_file`` and ``config_file``.
    ``__init__`` does not set those; they have to be assigned on the instance
    before the methods that use them are called.
    """
    def __init__(self):
        self.model = None
        self.model_file = ''
        self.plot_file = ''
        self.min_loss = 0

    def read_file(self, path):
        """Return the first line of ``path`` (with its newline, if present)."""
        with open(path, 'r') as file:
            return file.readline()

    def write_file(self, path, content):
        """Write ``content`` to ``path``.

        The plot file is appended to; every other file is overwritten.
        """
        mode = 'a' if path == self.plot_file else 'w'
        with open(path, mode=mode) as file:
            file.write(content)

    def create_file(self, path):
        """Create ``path`` as an empty file, truncating it if it exists."""
        with open(path, 'w'):
            pass

    def create_checkpoint_file(self, num):
        """Create an empty ``checkpoint_<num>.pth`` in CHECKPOINT_FOLDER and return its path."""
        # Make sure the folder is there so the very first checkpoint does not fail.
        os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
        path = f'{CHECKPOINT_FOLDER}/checkpoint_{num}.pth'
        self.create_file(path)
        return path

    def save_config(self, args: dict):
        """Dump ``args`` as YAML into ``self.config_file``, replacing its content."""
        with open(self.config_file, 'w') as file:
            yaml.safe_dump(args, file)

    def check_status_file(self):
        """Resume from the checkpoint named in the status file, or start fresh.

        The status file holds the path of the latest checkpoint. If it names
        one, the checkpoint is loaded and the plot file is cut back to its
        header plus one row per completed epoch. Otherwise the plot file is
        reset to just the header and the model is put into training mode.

        Returns:
            The epoch stored in the checkpoint, or 0 when starting fresh.
        """
        if not os.path.exists(self.status_file):
            self.create_file(self.status_file)
        # strip() so a trailing newline in the status file cannot end up in the path.
        checkpath = self.read_file(self.status_file).strip()

        if checkpath:
            epoch = self.load_checkpoint(checkpath)
            with open(self.plot_file, 'r') as file:
                lines = file.readlines()
            # Keep the header line plus `epoch` data rows; rows written after
            # the checkpoint was taken are discarded.
            with open(self.plot_file, 'w') as file:
                file.writelines(lines[:epoch+1])
            return epoch

        with open(self.plot_file, 'w') as file:
            file.write('Train_loss,Train_acc,Valid_loss,Valid_acc\n')
        self.model.train()
        return 0

    def write_plot_data(self, data:list):
        """Append ``data`` to the plot file as one comma-separated row."""
        str_data = ','.join(map(str, data))
        self.write_file(self.plot_file, f'{str_data}\n')

    def save_checkpoint(self, epoch, checkpath):
        """Save model state, optimizer state and ``epoch`` to ``checkpath``.

        The status file is updated only after the checkpoint has been written,
        so it never points at a checkpoint whose save failed.
        """
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optim_state_dict': self.optim.state_dict(),
            'epoch': epoch
        }
        torch.save(checkpoint, checkpath)
        with open(self.status_file, 'w') as file:
            file.write(checkpath)
        print('checkpoint saved..')

    def load_checkpoint(self, checkpath):
        """Restore model and optimizer state from ``checkpath``.

        The model is left in training mode.

        Returns:
            The epoch stored in the checkpoint.
        """
        print('loading checkpoint..')
        checkpoint = torch.load(checkpath)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optim.load_state_dict(checkpoint['optim_state_dict'])
        self.model.train()
        print('checkpoint loaded...')
        return checkpoint['epoch']

    def save_check_interval(self, epoch, interval=50):
        """Save a checkpoint when ``epoch`` is a positive multiple of ``interval``."""
        if epoch > 0 and epoch % interval == 0:
            checkpath = self.create_checkpoint_file(epoch)
            self.save_checkpoint(epoch, checkpath)

    def load_model(self):
        """Load the weights in ``self.model_file`` and switch the model to eval mode."""
        print('loading model...')
        self.model.load_state_dict(torch.load(self.model_file))
        self.model.eval()
        print('model loaded...')

    def save_model(self):
        """Write the model's state dict to ``self.model_file``."""
        torch.save(self.model.state_dict(), self.model_file)
        print('model saved...')

    def save_best_model(self, param, acc_param=True):
        """Save the model if ``param`` is at least as good as the best so far.

        Args:
            param: metric value of the current model.
            acc_param: True when larger values are better, False when smaller
                values are better.

        When ``param`` is worse than ``self.param`` nothing is written, so the
        saved model and the recorded best value always belong together.
        """
        if acc_param:
            improved = param >= self.param
        else:
            improved = param <= self.param
        if not improved:
            return
        self.param = param
        self.write_file(self.param_file, f'{param}')
        self.save_model()

    def check_param_file(self):
        """Return the value stored in the param file, creating the file if missing.

        A missing file is created holding -1000.0, which is also returned.
        """
        if os.path.exists(self.param_file):
            return float(self.read_file(self.param_file))
        # NOTE(review): -1000.0 is a sensible floor only when larger is better;
        # confirm how callers seed the file for smaller-is-better metrics.
        param = -1000.0
        self.write_file(self.param_file, f'{param}')
        return param


class Params:
    """Default hyper-parameters and image transforms for one training run.

    Plain attribute container; every value is set in ``__init__`` and can be
    overwritten on the instance afterwards.
    """
    def __init__(self):
        # Name of the run (how it is used is not visible in this file).
        self.name = 'model'
        # One [out_channels, kernel_size, stride] triple per conv layer,
        # the layout make_cnn() unpacks.
        self.conv_layers = [[128, 3, 1],
                            [64, 3, 1],
                            [32, 3, 1],
                            [16, 3, 1]]
        # Widths of the fully connected hidden layers.
        self.hid_layers = [1024, 512]
        # [kernel_size, stride] of the max-pooling layers.
        self.max_pool = [2, 2]
        # Pool after every this-many conv layers.
        self.pool_after_layers = 2
        # Activation name; make_cnn() accepts 'relu', 'softplus', 'tanh', 'elu'.
        self.act_fn = 'relu'
        # Optimisation settings (presumably consumed by the training loop,
        # which is not part of this file).
        self.lr = 1e-6
        self.epochs = 150
        self.clip_grad = 0.5
        self.metric_param = 'val_acc'
        self.batch_size = 32
        # The test pipeline must end in tensors just like the training one:
        # with Resize alone the test ImageFolder yields PIL images, which the
        # DataLoader's default collate function cannot stack into a batch.
        self.test_trans = tf.Compose([
            tf.Resize(size = (256,256)),
            tf.ToImageTensor(),
            tf.ConvertImageDtype()
        ])
        self.train_trans = tf.Compose([
            tf.Resize(size=(256, 256)),
            tf.ToImageTensor(),
            tf.ConvertImageDtype()
        ])
        # Fraction of the training images held out for validation.
        self.val_rat = 0.2

