In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision import disable_beta_transforms_warning
disable_beta_transforms_warning()
import torchvision.transforms.v2 as tf
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset, DataLoader
from tqdm import tqdm
import yaml
import os
import matplotlib.pyplot as plt
import random

# Select the compute device first; the assertion below refuses to continue without CUDA.
pu = 'cpu' if not torch.cuda.is_available() else 'cuda'
# Worker-process count handed to every DataLoader (whatever os.cpu_count() reports).
workers = os.cpu_count()

assert pu == 'cuda', 'Connect to the GPU'

print(f'using {pu}, number of workers: {workers}')



using cuda, number of workers: 8


In [9]:
# Root of the mounted Google Drive; every other location is derived from it.
home_folder = '/content/gdrive/MyDrive'

# Image data: one directory for training images, one for test images.
DATASET_FOLDER = '/'.join((home_folder, 'skdi_dataset'))
TRAINING_FOLDER = '/'.join((DATASET_FOLDER, 'train_dir'))
TESTING_FOLDER = '/'.join((DATASET_FOLDER, 'base_dir'))

# Project workspace and the per-artifact sub-directories inside it.
WORKSPACE = '/'.join((home_folder, 'Skin_Diseases_Detection'))
CHECKPOINT_FOLDER = '/'.join((WORKSPACE, 'checkpoints'))
STATUS_FOLDER = '/'.join((WORKSPACE, 'status'))
PLOT_FOLDER = '/'.join((WORKSPACE, 'plots'))
PARAM_FOLDER = '/'.join((WORKSPACE, 'param'))
CONFIG_FOLDER = '/'.join((WORKSPACE, 'config'))
MODEL_FOLDER = '/'.join((WORKSPACE, 'model'))

class Dataset:
    """Builds the train / validation / test datasets and their data loaders.

    Attributes:
        _ds: the full ``ImageFolder`` read from ``TRAINING_FOLDER``.
        train_ds, val_ds: disjoint ``Subset`` views of ``_ds``.
        test_ds: ``ImageFolder`` read from ``TESTING_FOLDER``.
        train_dl, val_dl, test_dl: loaders over the three datasets; all use
            ``train_batch_size`` and ``workers`` worker processes.
    """

    def __init__(self, train_trans, test_trans, train_batch_size = 32, val_rat = 0.2, seed = 0):
        """
        Args:
            train_trans: transform applied to every image of the training folder.
            test_trans: transform applied to every image of the testing folder.
            train_batch_size: batch size used by all three loaders.
            val_rat: fraction of the training folder held out for validation.
            seed: seed of the permutation that decides which samples are held
                out; keeping it fixed reproduces the same split on every run,
                which matters when training is resumed from a checkpoint.
        """
        self._ds = ImageFolder(TRAINING_FOLDER, train_trans)
        self.test_ds = ImageFolder(TESTING_FOLDER, test_trans)
        ds_size = len(self._ds)
        val_size = int(val_rat*ds_size)
        train_size = ds_size - val_size

        # ImageFolder lists its samples class by class, so holding out a
        # contiguous tail slice would put only the last class(es) in the
        # validation set. Split on a seeded random permutation instead.
        generator = torch.Generator().manual_seed(seed)
        indices = torch.randperm(ds_size, generator=generator).tolist()
        self.train_ds = Subset(self._ds, indices[:train_size])
        # NOTE(review): the validation subset shares `_ds`, so it still goes
        # through `train_trans` (including any random augmentation).
        self.val_ds = Subset(self._ds, indices[train_size:])

        self.train_dl = DataLoader(self.train_ds, batch_size=train_batch_size, shuffle=True, num_workers=workers)
        # Evaluation loaders do not need shuffling; a fixed order keeps runs comparable.
        self.test_dl = DataLoader(self.test_ds, batch_size=train_batch_size, shuffle=False, num_workers=workers)
        self.val_dl = DataLoader(self.val_ds, batch_size=train_batch_size, shuffle=False, num_workers=workers)

def make_cnn(dataset: ImageFolder, hid_layers = [64, 64],
            act_fn='relu', max_pool = None, pooling_after_layers = 2, dropout = 0.2, batch_norm=True,
            conv_layers=[[32, 3, 1],
                         [16, 3, 1]]):
    """Assemble a convolutional classifier as an ``nn.Sequential``.

    The input shape is taken from the first sample of ``dataset`` and the
    number of outputs from ``len(dataset.classes)``.

    Args:
        dataset: object indexable as ``dataset[0][0]`` (a ``(C, H, W)`` tensor)
            and exposing ``classes``.
        hid_layers: widths of the fully connected hidden layers.
        act_fn: one of ``'relu'``, ``'softplus'``, ``'tanh'``, ``'elu'``.
        max_pool: ``None`` for no pooling, else ``[kernel_size, stride]``.
        pooling_after_layers: pool after every this many conv layers (and
            always after the last one) when ``max_pool`` is given.
        dropout: dropout probability used in hidden layers after the first,
            or ``None`` to omit dropout.
        batch_norm: insert ``BatchNorm2d`` after each convolution.
        conv_layers: ``[out_channels, kernel_size, stride]`` per convolution.

    Returns:
        The assembled ``nn.Sequential`` model (not moved to any device).
    """
    first_image = dataset[0][0]
    input_shape = first_image.shape
    num_of_classes = len(dataset.classes)
    # A single module instance per activation kind is reused at every position.
    activation_fun = {'relu': nn.ReLU(), 'softplus':nn.Softplus(), 'tanh':nn.Tanh(), 'elu': nn.ELU()}

    assert pooling_after_layers < len(conv_layers), 'exceeding the number conv layers..'

    def shrink(size, kernel, step):
        # Output length of an unpadded sliding window (conv or pool) along one axis.
        return (size - kernel)//step + 1

    layers = []
    in_chann, height, width = input_shape
    last_index = len(conv_layers) - 1
    for ind, (out_chann, filter_size, stride) in enumerate(conv_layers):
        layers.append(nn.Conv2d(in_chann, out_chann, filter_size, stride))
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_chann))
        layers.append(activation_fun[act_fn])
        height, width = shrink(height, filter_size, stride), shrink(width, filter_size, stride)

        if max_pool is not None:
            if (ind + 1) % pooling_after_layers == 0 or ind == last_index:
                pool_kernel, pool_stride = max_pool[0], max_pool[1]
                layers.append(nn.MaxPool2d(pool_kernel, pool_stride))
                height = shrink(height, pool_kernel, pool_stride)
                width = shrink(width, pool_kernel, pool_stride)
        in_chann = out_chann

    # Dense head: flattened feature map -> hidden widths -> class logits.
    layers.append(nn.Flatten())
    dims = [height*width*in_chann] + list(hid_layers)
    for position, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(in_dim, out_dim))
        # The first hidden layer carries no dropout; later ones do when enabled.
        if position > 0 and dropout is not None:
            layers.append(nn.Dropout(p=dropout))
        layers.append(activation_fun[act_fn])

    layers.append(nn.Linear(hid_layers[-1], num_of_classes))
    return nn.Sequential(*layers)

def convert(**kwargs):
    """Collect the given keyword arguments into a plain dictionary and return it."""
    collected = dict(kwargs)
    return collected

class Utils:
    """File helpers for checkpoints, training status, plot data and best-model tracking.

    Subclasses are expected to provide ``model``, ``optim`` and the file paths
    ``model_file``, ``plot_file``, ``status_file``, ``param_file`` and
    ``config_file`` before calling the methods that use them.
    """

    # Value stored in a freshly created parameter file; it means "no best metric yet".
    _NO_BEST = -1000.0

    def __init__(self):
        self.model = None
        self.optim = None
        self.model_file = ''
        self.plot_file = ''
        self.status_file = ''
        self.param_file = ''
        self.config_file = ''
        self.param = self._NO_BEST
        self.min_loss = 0

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` if it does not exist yet."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def read_file(self, path):
        """Return the first line of ``path`` (empty string for an empty file)."""
        with open(path, 'r') as file:
            return file.readline()

    def write_file(self, path, content):
        """Write ``content`` to ``path``; the plot file is appended to, others are overwritten."""
        mode = 'a' if path == self.plot_file else 'w'
        self._ensure_parent_dir(path)
        with open(path, mode=mode) as file:
            file.write(content)

    def create_file(self, path):
        """Create (or truncate) an empty file at ``path``."""
        self._ensure_parent_dir(path)
        with open(path, 'w'):
            pass

    def create_checkpoint_file(self, num):
        """Create an empty ``checkpoint_<num>.pth`` in ``CHECKPOINT_FOLDER`` and return its path."""
        path = f'{CHECKPOINT_FOLDER}/checkpoint_{num}.pth'
        self.create_file(path)
        return path

    def save_config(self, args: dict):
        """Dump ``args`` as YAML into ``self.config_file``, replacing previous content."""
        self._ensure_parent_dir(self.config_file)
        with open(self.config_file, 'w') as file:
            yaml.safe_dump(args, file)

    def check_status_file(self):
        """Resume from the checkpoint named in the status file, or start a fresh plot file.

        Returns:
            The epoch stored in the loaded checkpoint, or 0 when the status
            file names no checkpoint.
        """
        if not os.path.exists(self.status_file):
            self.create_file(self.status_file)
        checkpath = self.read_file(self.status_file).strip()
        header = 'Train_loss,Train_acc,Valid_loss,Valid_acc\n'
        epoch = 0
        if checkpath != '':
            epoch = self.load_checkpoint(checkpath)
            # Keep the header plus the first `epoch` data rows; later rows are
            # discarded so they can be rewritten.
            # NOTE(review): a caller that starts its loop at the returned value
            # repeats the checkpointed epoch - confirm this is intended.
            if os.path.exists(self.plot_file):
                with open(self.plot_file, 'r') as file:
                    lines = file.readlines()
            else:
                lines = [header]
            self._ensure_parent_dir(self.plot_file)
            with open(self.plot_file, 'w') as file:
                file.writelines(lines[:epoch+1])
        else:
            self._ensure_parent_dir(self.plot_file)
            with open(self.plot_file, 'w') as file:
                file.write(header)
            self.model.train()
        return epoch

    def write_plot_data(self, data:list):
        """Append one comma-separated row built from ``data`` to the plot file."""
        str_data = ','.join(map(str, data))
        self.write_file(self.plot_file, f'{str_data}\n')

    def save_checkpoint(self, epoch, checkpath):
        """Save model/optimizer state and ``epoch`` to ``checkpath`` and record it in the status file."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optim_state_dict': self.optim.state_dict(),
            'epoch': epoch
        }
        self._ensure_parent_dir(checkpath)
        torch.save(checkpoint, checkpath)
        # Point the status file at the checkpoint only once it is fully
        # written, so a failed save never leaves a dangling reference.
        self._ensure_parent_dir(self.status_file)
        with open(self.status_file, 'w') as file:
            file.write(checkpath)
        print('checkpoint saved..')

    def load_checkpoint(self, checkpath):
        """Restore model and optimizer state from ``checkpath``; return the stored epoch."""
        print('loading checkpoint..')
        # torch.load unpickles the file: only load checkpoints you created yourself.
        checkpoint = torch.load(checkpath)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optim.load_state_dict(checkpoint['optim_state_dict'])
        self.model.train()
        print('checkpoint loaded...')
        return checkpoint['epoch']

    def save_check_interval(self, epoch, interval=50):
        """Save a checkpoint when ``epoch`` is a positive multiple of ``interval``."""
        if not(epoch % interval) and epoch > 0:
            checkpath = self.create_checkpoint_file(epoch)
            self.save_checkpoint(epoch, checkpath)

    def load_model(self):
        """Load weights from ``self.model_file`` and switch the model to eval mode."""
        print('loading model...')
        self.model.load_state_dict(torch.load(self.model_file))
        self.model.eval()
        print('model loaded...')

    def save_model(self):
        """Write the model's state dict to ``self.model_file``."""
        self._ensure_parent_dir(self.model_file)
        torch.save(self.model.state_dict(), self.model_file)
        print('model saved...')

    def save_best_model(self, param, acc_param=True):
        """Save the model only when ``param`` beats the best value seen so far.

        Args:
            param: metric value of the current epoch.
            acc_param: True when higher is better (accuracy), False when
                lower is better (loss).
        """
        best = self.param
        if best == self._NO_BEST:
            # Nothing recorded yet: the first measured value is the best by definition.
            improved = True
        elif acc_param:
            improved = param > best
        else:
            improved = param < best
        if not improved:
            return
        self.param = param
        self.write_file(self.param_file, f'{param}')
        self.save_model()

    def check_param_file(self):
        """Return the best metric stored in the parameter file, creating the file if needed."""
        if os.path.exists(self.param_file):
            stored = self.read_file(self.param_file).strip()
            if stored:
                return float(stored)
        # Missing or empty file: initialise it with the "no best yet" marker.
        param = self._NO_BEST
        self.write_file(self.param_file, f'{param}')
        return param
    

class skdi_detector(Utils):
    """Image classifier wrapper: owns the data loaders, model, optimizer and training loop."""

    _METRICS = ('train_loss', 'train_acc', 'val_loss', 'val_acc')

    def __init__(self,params):
        """
        Args:
            params: object exposing the attributes read below and in
                ``create_model`` / ``train`` (transforms, batch size,
                validation ratio, name, metric name, clip value, ...).

        Raises:
            ValueError: if ``params.metric_param`` is not a tracked metric.
        """
        super().__init__()
        self.params = params
        # Fail before any data is loaded rather than after the first full epoch.
        if params.metric_param not in self._METRICS:
            raise ValueError(f'metric_param must be one of {self._METRICS}, got {params.metric_param!r}')
        self.dataset = Dataset(train_trans=params.train_trans, test_trans=params.test_trans,
                               train_batch_size=params.batch_size, val_rat=params.val_rat)
        self.loss_fn = CrossEntropyLoss()
        self.name = params.name
        self.model_file = f'{MODEL_FOLDER}/{self.name}_model.pth'
        self.status_file = f'{STATUS_FOLDER}/{self.name}_status.txt'
        self.plot_file = f'{PLOT_FOLDER}/{self.name}_plot.txt'
        self.param_file = f'{PARAM_FOLDER}/{self.name}_param.txt'
        self.config_file = f'{CONFIG_FOLDER}/{self.name}_config.yaml'
        self.param = self.check_param_file()
        self.metric_param = params.metric_param
        self.clip_grad = params.clip_grad
        # Accuracy metrics are maximised, loss metrics minimised.
        self.acc_param = self.metric_param in ['train_acc', 'val_acc']

    @staticmethod
    def _batch_mean(total, num_batches):
        """Average an accumulated per-batch sum; NaN when the loader produced no batch."""
        return total / num_batches if num_batches else float('nan')

    def train_step(self):
        """Run one optimisation pass over the training loader.

        Returns:
            ``(train_loss, train_acc)``, each the mean of the per-batch values.
        """
        # validate_step leaves the model in eval mode; switch back for training.
        self.model.train()
        train_loss, train_acc = 0.0, 0.0
        for x, y in self.dataset.train_dl:
            x, y = x.to(pu), y.to(pu)
            y_pred_logits = self.model(x)

            loss = self.loss_fn(y_pred_logits, y)
            train_loss += loss.item()

            self.optim.zero_grad()
            loss.backward()
            if self.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
            self.optim.step()

            # softmax is monotonic, so argmax over the logits gives the same class.
            y_pred = torch.argmax(y_pred_logits, dim=-1)
            train_acc += (y_pred == y).sum().item()/len(y)

        num_batches = len(self.dataset.train_dl)
        return self._batch_mean(train_loss, num_batches), self._batch_mean(train_acc, num_batches)

    def validate_step(self):
        """Evaluate the model on the validation loader without updating it.

        Returns:
            ``(val_loss, val_acc)``, each the mean of the per-batch values
            (NaN when the validation loader is empty).
        """
        # Eval mode disables dropout and stops BatchNorm from updating its
        # running statistics with validation batches.
        self.model.eval()
        val_loss, val_acc = 0.0, 0.0
        with torch.no_grad():
            for x, y in self.dataset.val_dl:
                x, y = x.to(pu), y.to(pu)
                y_pred_logits = self.model(x)

                loss = self.loss_fn(y_pred_logits, y)
                val_loss += loss.item()

                y_pred = torch.argmax(y_pred_logits, dim=-1)
                val_acc += (y_pred == y).sum().item()/len(y)

        num_batches = len(self.dataset.val_dl)
        return self._batch_mean(val_loss, num_batches), self._batch_mean(val_acc, num_batches)

    def create_model(self):
        """Build the CNN from ``self.params``, move it to the device and create the optimizer."""
        self.model = make_cnn(dataset=self.dataset._ds, hid_layers=self.params.hid_layers, act_fn=self.params.act_fn,
                              max_pool=self.params.max_pool, pooling_after_layers=self.params.pool_after_layers,
                              batch_norm=self.params.batch_norm, conv_layers=self.params.conv_layers, dropout=self.params.dropout).to(pu)

        self.optim = Adam(self.model.parameters(), lr=self.params.lr, eps=1e-6, weight_decay=1e-5)
        print(f'Model: {self.model}')
        print(f'Number of classes: {self.dataset._ds.classes}')
        print(f'Input image size: {self.dataset._ds[0][0][0].shape}')

    def plot_images(self):
        """Show three randomly chosen training images side by side."""
        fig, axes = plt.subplots(1, 3, figsize=(10, 10))
        for i in range(3):
            ind = random.randint(0, len(self.dataset.train_ds)-1)
            img = self.dataset.train_ds[ind][0]
            img = img.numpy()
            # Channel-first RGB tensors are moved to channel-last for imshow.
            img = img.transpose((1, 2, 0)) if img.shape[0] == 3 else img
            axes.flat[i].imshow(img)
            axes.flat[i].axis('off')

    def train(self):
        """Train up to ``params.epochs`` epochs, resuming from the status file if present.

        Each epoch logs the four metrics, appends them to the plot file,
        saves a checkpoint and offers the tracked metric to ``save_best_model``.

        Raises:
            RuntimeError: if ``create_model`` has not been called yet.
        """
        if getattr(self, 'model', None) is None:
            raise RuntimeError('call create_model() before train()')
        epochs = self.params.epochs
        start_epoch = self.check_status_file()
        print(f'training for {epochs} epochs....')
        # Epochs are numbered from 0, so the last one is epochs - 1.
        for ep in tqdm(range(start_epoch, epochs)):
            train_loss, train_acc = self.train_step()
            val_loss, val_acc = self.validate_step()

            metric_param = {'train_loss': train_loss, 'train_acc': train_acc,
                            'val_loss': val_loss, 'val_acc': val_acc}

            print(f'epochs: {ep}\t{train_loss = :.4f}\t{train_acc = :.4f}\t{val_loss = :.4f}\t{val_acc = :.4f}')
            self.write_plot_data([train_loss, train_acc, val_loss, val_acc])
            self.save_check_interval(epoch=ep, interval=1)
            self.save_best_model(acc_param=self.acc_param, param=metric_param[self.metric_param])

        print('Finished Training....')
    
class Params:
    """Hyper-parameters, architecture description and image transforms for one run."""

    def __init__(self):
        # Run identity and optimisation schedule.
        self.name = 'model'
        self.epochs = 100
        self.lr = 1e-5
        self.batch_size = 16
        self.clip_grad = 0.3
        self.val_rat = 0.01
        self.metric_param = 'val_acc'

        # Convolutional stack: two [out_channels, kernel, stride] layers per width.
        self.conv_layers = [[channels, 3, 1]
                            for channels in (128, 64, 32, 16)
                            for _ in range(2)]
        self.max_pool = [2, 2]
        self.pool_after_layers = 2
        self.batch_norm = True

        # Fully connected head.
        self.hid_layers = [1024, 512, 256]
        self.act_fn = 'relu'
        self.dropout = 0.2

        # Both pipelines resize to the same square; only training adds random flips.
        image_size = (256, 256)
        self.test_trans = tf.Compose([
            tf.Resize(size = image_size),
            tf.ToTensor()
        ])
        self.train_trans = tf.Compose([
            tf.Resize(size=image_size),
            tf.RandomHorizontalFlip(p=0.8),
            tf.RandomVerticalFlip(),
            tf.ToTensor()
        ])

In [None]:
# Hyper-parameter bundle first, then the detector that loads the datasets from it.
params = Params()
agent = skdi_detector(params)

# Build the network and optimizer, then report how the training folder was split.
agent.create_model()
train_split, val_split = agent.dataset.train_ds, agent.dataset.val_ds
print(f'training dataset size: {len(train_split)}')
print(f'validation dataset size: {len(val_split)}')


In [None]:
# Release cached, currently unused GPU memory held by PyTorch before training starts.
torch.cuda.empty_cache()
# Run the training loop of the `agent` built in the previous cell.
agent.train()