In [1]:
import os
import cv2
import sys
import zipfile
import traceback
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

In [None]:
import yaml
import joblib
import torch
import torch.nn as nn


def config():
    """Read the project configuration file and return it as parsed YAML.

    The file is opened relative to the current working directory at
    ``../../config.yml`` and parsed with ``yaml.safe_load``.
    """
    config_path = "../../config.yml"
    with open(config_path, "r") as handle:
        settings = yaml.safe_load(handle)
    return settings


def dump(value=None, filename=None):
    """Serialize ``value`` to ``filename`` with joblib.

    Does nothing (and returns None) when either argument is None.
    """
    if value is None or filename is None:
        return
    joblib.dump(value=value, filename=filename)


def load(filename=None):
    """Deserialize and return the joblib object stored at ``filename``.

    Returns None when ``filename`` is None.
    """
    return None if filename is None else joblib.load(filename=filename)


class CustomException(Exception):
    """Project-specific error type.

    ``message`` becomes the exception text (``str(exc)`` and ``exc.args``).
    Any extra positional arguments are accepted and kept on ``self.context``.
    Call sites such as ``CustomException(e, sys)`` therefore construct the
    exception instead of failing with a TypeError about too many arguments.
    """

    def __init__(self, message=None, *context):
        super().__init__(message)
        # Extra diagnostic values passed by callers; not part of the message.
        self.context = context


def weight_init(m):
    """Initialise a module's parameters in place (intended for ``model.apply``).

    Modules whose class name contains "Conv" get weights ~ N(0, 0.02).
    Modules whose class name contains "BatchNorm" get weights ~ N(1, 0.02)
    and zero bias. Every other module is left untouched.
    """
    layer_name = type(m).__name__

    if "Conv" in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


def device_init(device="cuda"):
    """Return the requested ``torch.device``, falling back to CPU.

    "cuda" and "mps" are honoured only when that backend reports itself
    available. Any other value yields the CPU device.
    """
    if device == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    if device == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

In [None]:
class Loader:
    """Turn a zipped folder of paired images into pickled PyTorch DataLoaders.

    Layout read by ``extract_features`` after ``unzip_folder``:
    ``<RAW_DATA_PATH>/dataset/X/<name>`` and ``<RAW_DATA_PATH>/dataset/y/<name>``.
    Files with the same ``<name>`` in both folders form one (actual, target)
    pair. All paths come from ``config()["path"]``.
    """

    def __init__(
        self, image_path=None, channels=3, image_size=256, batch_size=4, split_size=0.20
    ):
        self.image_path = image_path  # path of the .zip archive to extract
        self.channels = channels  # stored only; the Normalize transform assumes 3 channels
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size  # fraction of pairs held out for validation

        self.actual = []
        self.target = []

    def unzip_folder(self):
        """Extract the archive at ``self.image_path`` into RAW_DATA_PATH.

        Raises:
            CustomException: if RAW_DATA_PATH does not exist.
        """
        self.raw_data_path = config()["path"]["RAW_DATA_PATH"]

        if not os.path.exists(self.raw_data_path):
            raise CustomException("Raw data path does not exist".capitalize())

        with zipfile.ZipFile(self.image_path, "r") as zip_file:
            zip_file.extractall(path=self.raw_data_path)

        print(
            "Unzip is done successfully and stoed in the path {}".format(
                os.path.join(self.raw_data_path, "dataset")
            )
        )

    def transforms(self):
        """Return the preprocessing pipeline.

        The steps are: resize to ``image_size``, convert to a tensor, center-crop,
        then normalise each channel from [0, 1] to [-1, 1].
        """
        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def split_dataset(self, X, y):
        """Split paired lists into train/test parts.

        ``split_size`` of the pairs are held out, with ``random_state=42``.

        Returns:
            dict with keys ``X_train``, ``X_test``, ``y_train``, ``y_test``.

        Raises:
            CustomException: if ``X`` or ``y`` is not a list.
        """
        if not (isinstance(X, list) and isinstance(y, list)):
            raise CustomException("X and y should be list".capitalize())

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.split_size, random_state=42
        )

        return {
            "X_train": X_train,
            "X_test": X_test,
            "y_train": y_train,
            "y_test": y_test,
        }

    @staticmethod
    def _read_rgb_image(path):
        """Read ``path`` with OpenCV and return it as an RGB PIL image."""
        image = cv2.imread(filename=path, flags=cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None (it does not raise) for unreadable files.
            # Fail here with the offending path rather than later inside cvtColor.
            raise CustomException("Could not read image: {}".format(path))
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    def extract_features(self):
        """Load and transform every X/y pair, then split into train/test.

        Returns:
            The dict produced by ``split_dataset``.

        Raises:
            Any error from ``split_dataset``. It is printed with its traceback
            and then re-raised.
        """
        raw_data_path = config()["path"]["RAW_DATA_PATH"]
        self.directory = os.path.join(raw_data_path, "dataset")
        self.X = os.path.join(self.directory, "X")
        self.y = os.path.join(self.directory, "y")

        # Start from empty lists so a second call does not duplicate samples.
        self.actual = []
        self.target = []

        transform = self.transforms()  # build the pipeline once, not per image
        target_names = set(os.listdir(self.y))  # O(1) pair lookup

        # Sorted so the seeded split is reproducible regardless of the
        # filesystem's directory ordering.
        for image in tqdm(sorted(os.listdir(self.X))):
            if image not in target_names:
                continue

            self.imageX = transform(self._read_rgb_image(os.path.join(self.X, image)))
            self.imagey = transform(self._read_rgb_image(os.path.join(self.y, image)))

            self.actual.append(self.imageX)
            self.target.append(self.imagey)

        assert len(self.actual) == len(self.target)

        try:
            dataset = self.split_dataset(X=self.actual, y=self.target)
        except Exception as e:
            print("An error occured: ", e)
            traceback.print_exc()
            # Re-raise: continuing would return an unbound ``dataset``.
            raise

        print("Feature extracted successfully".capitalize())
        return dataset

    def create_dataloader(self):
        """Build shuffled train/valid DataLoaders and pickle them.

        They are written to PROCESSED_DATA_PATH as ``train_dataloader.pkl`` and
        ``valid_dataloader.pkl``.
        """
        self.dataset = self.extract_features()
        self.processed_data_path = config()["path"]["PROCESSED_DATA_PATH"]
        os.makedirs(self.processed_data_path, exist_ok=True)

        self.train_dataloader = DataLoader(
            dataset=list(zip(self.dataset["X_train"], self.dataset["y_train"])),
            batch_size=self.batch_size,
            shuffle=True,
        )

        self.valid_dataloader = DataLoader(
            dataset=list(zip(self.dataset["X_test"], self.dataset["y_test"])),
            # NOTE(review): validation batch size is batch_size ** 2; kept as-is, confirm intent.
            batch_size=self.batch_size * self.batch_size,
            shuffle=True,
        )

        for value, filename in [
            (self.train_dataloader, "train_dataloader.pkl"),
            (self.valid_dataloader, "valid_dataloader.pkl"),
        ]:
            dump(value=value, filename=os.path.join(self.processed_data_path, filename))

        print(
            "DataLoader created successfully and stored in the path {}".capitalize().format(
                self.processed_data_path
            )
        )

    @staticmethod
    def _grid_shape(batch_size):
        """Return (rows, columns) for laying out ``batch_size`` image pairs.

        Both values are at least 1, so a batch of one image still plots.
        """
        rows = max(batch_size // 2, 1)
        columns = max(batch_size // rows, 1)
        return rows, columns

    @staticmethod
    def _min_max_scale(array):
        """Rescale ``array`` to [0, 1]; a constant array maps to all zeros."""
        low = array.min()
        span = array.max() - low
        return (array - low) / span if span else array - low

    @staticmethod
    def plot_images():
        """Plot one validation batch as actual/target pairs and save ``images.png``.

        The figure is written to FILES_PATH.
        """
        paths = config()["path"]
        processed_data_path = paths["PROCESSED_DATA_PATH"]
        files_path = paths["FILES_PATH"]

        valid_dataloader = load(
            filename=os.path.join(processed_data_path, "valid_dataloader.pkl")
        )

        X, y = next(iter(valid_dataloader))

        number_of_rows, number_of_columns = Loader._grid_shape(X.size(0))

        plt.figure(figsize=(20, 10))

        for index, (image, target) in enumerate(zip(X, y)):
            # CHW tensors -> HWC arrays scaled to [0, 1] for imshow.
            imageX = Loader._min_max_scale(image.permute(1, 2, 0).detach().numpy())
            imagey = Loader._min_max_scale(target.permute(1, 2, 0).detach().numpy())

            plt.subplot(2 * number_of_rows, 2 * number_of_columns, 2 * index + 1)
            plt.title("actual".capitalize())
            plt.imshow(imageX)
            plt.axis("off")

            plt.subplot(2 * number_of_rows, 2 * number_of_columns, 2 * index + 2)
            plt.title("target".capitalize())
            plt.imshow(imagey)
            plt.axis("off")

        plt.tight_layout()
        plt.savefig(os.path.join(files_path, "images.png"))
        plt.show()

        print("Images saved in the path {}".format(files_path).capitalize())

    @staticmethod
    def details_dataset():
        """Write sample counts and batch shapes of both dataloaders to ``details.csv``.

        The CSV is written to FILES_PATH.
        """
        import pandas as pd  # pandas is not imported at the top of this notebook

        paths = config()["path"]
        processed_data_path = paths["PROCESSED_DATA_PATH"]
        files_path = paths["FILES_PATH"]

        train_dataloader = load(
            filename=os.path.join(processed_data_path, "train_dataloader.pkl")
        )
        valid_dataloader = load(
            filename=os.path.join(processed_data_path, "valid_dataloader.pkl")
        )

        trainX, _ = next(iter(train_dataloader))
        validX, _ = next(iter(valid_dataloader))

        # Count each loader once and reuse the totals.
        train_total = sum(X.size(0) for X, _ in train_dataloader)
        valid_total = sum(X.size(0) for X, _ in valid_dataloader)

        # NOTE(review): "total_data(X)" / "total_data(y)" actually hold the
        # train / valid sample counts; column names kept for compatibility.
        dataframe = pd.DataFrame(
            {
                "total_data(X)": [str(train_total)],
                "total_data(y)": [str(valid_total)],
                "total_data(X+y)": [str(train_total + valid_total)],
                "train_image_size(X)": [str(trainX.size())],
                "valid_image_size(X)": [str(validX.size())],
            },
            index=["details".capitalize()],
        )

        dataframe.to_csv(os.path.join(files_path, "details.csv"))

        print("dataset details saved in the path {}".format(files_path).capitalize())


if __name__ == "__main__":
    # Script entry: extract the raw archive, build and pickle the train/valid
    # dataloaders, then save a preview grid (images.png) and a details.csv
    # summary under FILES_PATH.
    # NOTE(review): image_path is relative to the working directory while
    # config() reads ../../config.yml; confirm both resolve from the same cwd.
    loader = Loader(image_path="./data/raw/dataset1.zip")

    loader.unzip_folder()
    loader.create_dataloader()

    Loader.plot_images()
    Loader.details_dataset()

### Encoder

In [None]:
import os
import sys
import torch
import argparse
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Downsampling block: strided Conv2d (no bias), then ReLU, then BatchNorm2d.

    With kernel 4, stride 2 and padding 1 the convolution halves even spatial
    sizes and maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels=3, out_channels=128):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Fixed convolution geometry shared by every encoder block.
        self.kernel_size, self.stride_size, self.padding_size = 4, 2, 1

        self.encoder = self.encoder_block()

    def encoder_block(self):
        """Return the conv -> ReLU -> batch-norm stack for this block."""
        convolution = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=False,
        )
        activation = nn.ReLU(inplace=True)
        normalization = nn.BatchNorm2d(num_features=self.out_channels)
        return nn.Sequential(convolution, activation, normalization)

    def forward(self, x):
        """Apply the block to tensor ``x``; anything else raises ``Exception``."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Input must be a tensor".capitalize())
        return self.encoder(x)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Encoder Block for Variational Autoencoder".capitalize()
    )
    parser.add_argument(
        "--in_channels",
        type=int,
        default=3,
        help="Number of input channels".capitalize(),
    )
    parser.add_argument(
        "--out_channels",
        type=int,
        default=128,
        help="Number of output channels".capitalize(),
    )

    # parse_known_args: under Jupyter, sys.argv carries kernel flags such as
    # "-f <connection-file>" that parse_args() would reject with SystemExit.
    args, _ = parser.parse_known_args()

    # Build from the parsed CLI values (defaults 3 -> 128 -> 64).
    in_channels = args.in_channels
    out_channels = args.out_channels

    layers = []

    for _ in range(2):
        layers.append(EncoderBlock(in_channels=in_channels, out_channels=out_channels))
        in_channels = out_channels
        out_channels //= 2

    model = nn.Sequential(*layers)

    # Two stride-2 blocks take 256x256 to 64x64; the second block emits
    # args.out_channels // 2 channels.
    assert model(torch.randn(1, args.in_channels, 256, 256)).size() == (
        1,
        args.out_channels // 2,
        64,
        64,
    )

### Decoder

In [None]:
import os
import sys
import torch
import argparse
import torch.nn as nn

# NOTE(review): appends "src/" (relative to the current working directory) to the
# import path; nothing in this cell imports from it — confirm it is still needed.
sys.path.append("src/")


class DecoderBlock(nn.Module):
    """Upsampling block: strided ConvTranspose2d (no bias) followed by ReLU.

    With kernel 4, stride 2 and padding 1 the transposed convolution doubles
    the spatial size and maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels=64, out_channels=128):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Fixed transposed-convolution geometry shared by every decoder block.
        self.kernel_size, self.stride_size, self.padding_size = 4, 2, 1

        self.decoder = self.decoder_block()

    def decoder_block(self):
        """Return the transposed-conv -> ReLU stack for this block."""
        upsample = nn.ConvTranspose2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=False,
        )
        return nn.Sequential(upsample, nn.ReLU(inplace=True))

    def forward(self, x):
        """Apply the block to tensor ``x``; anything else raises ``Exception``."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Input must be a torch.Tensor".capitalize())
        return self.decoder(x)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Decoder Block for Variational Autoencoder".capitalize()
    )
    parser.add_argument(
        "--in_channels",
        type=int,
        default=64,
        help="Number of input channels".capitalize(),
    )
    parser.add_argument(
        "--out_channels",
        type=int,
        default=128,
        help="Number of output channels".capitalize(),
    )

    # parse_known_args: under Jupyter, sys.argv carries kernel flags such as
    # "-f <connection-file>" that parse_args() would reject with SystemExit.
    args, _ = parser.parse_known_args()

    layers = []

    # Build from the parsed CLI values (defaults 64 -> 128).
    in_channels = args.in_channels
    out_channels = args.out_channels

    layers.append(DecoderBlock(in_channels=in_channels, out_channels=out_channels))

    in_channels = out_channels

    # Final projection to 3 channels; with the DecoderBlock this upsamples
    # 64x64 -> 128x128 -> 256x256.
    layers.append(
        nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=3, kernel_size=4, stride=2, padding=1
        )
    )

    model = nn.Sequential(*layers)

    assert model(torch.randn(1, args.in_channels, 64, 64)).size() == (1, 3, 256, 256)

### VAE

In [None]:
import os
import sys
import torch
import argparse
import torch.nn as nn
from torchsummary import summary
from torchview import draw_graph

class VariationalAutoEncoder(nn.Module):
    """Convolutional VAE built from two EncoderBlocks and one DecoderBlock.

    Channel widths are tied to ``image_size``:
    - The encoder maps channels -> image_size // 2 -> image_size // 4.
    - The mean / log-variance heads keep image_size // 4 channels.
    - The decoder maps image_size // 4 -> image_size // 2 -> ``channels``.

    ``forward`` returns only the reconstruction.
    """

    def __init__(self, channels=3, image_size=256):
        super(VariationalAutoEncoder, self).__init__()

        self.in_channels = channels
        self.out_channels = image_size // 2

        # Geometry of the latent heads: 3x3, stride 1, padding 1 keeps H and W.
        self.kernel_size = 3
        self.stride_size = 1
        self.padding_size = 1

        self.encoder_layers = []
        self.decoder_layers = []

        # Two EncoderBlocks, halving the channel width after the first one.
        for _ in range(2):
            self.encoder_layers.append(
                EncoderBlock(
                    in_channels=self.in_channels, out_channels=self.out_channels
                )
            )
            self.in_channels = self.out_channels
            self.out_channels //= 2

        self.encoder = nn.Sequential(*self.encoder_layers)

        # After the loop self.in_channels == image_size // 4 (the encoder output width).
        self.mean = nn.Sequential(
            nn.Conv2d(
                in_channels=image_size // 4,
                out_channels=image_size // 4,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
                bias=False,
            )
        )

        self.log_variance = nn.Sequential(
            nn.Conv2d(
                in_channels=image_size // 4,
                out_channels=image_size // 4,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
                bias=False,
            )
        )

        self.out_channels = self.in_channels * 2

        self.decoder_layers.append(
            DecoderBlock(in_channels=self.in_channels, out_channels=self.out_channels)
        )
        # Output projection: kernel 4 / stride 2 / padding 1 doubles H and W.
        self.decoder_layers.append(
            nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=self.out_channels,
                    out_channels=channels,
                    kernel_size=self.kernel_size + 1,
                    stride=self.stride_size + 1,
                    padding=self.padding_size,
                )
            )
        )

        self.decoder = nn.Sequential(*self.decoder_layers)

    def reparameterization_trick(self, mean, log_variance):
        """Sample ``z = mean + eps * exp(0.5 * log_variance)`` with ``eps ~ N(0, 1)``.

        ``eps`` comes from ``torch.randn_like``, so it shares the device and
        dtype of the inputs. A plain ``torch.randn(size)`` is always created on
        the CPU and fails once the model has been moved to a GPU.

        Raises:
            Exception: if either argument is not a tensor.
        """
        if not (isinstance(mean, torch.Tensor) and isinstance(log_variance, torch.Tensor)):
            raise Exception("Input must be a tensor".capitalize())

        standard_deviation = torch.exp(0.5 * log_variance)
        eps = torch.randn_like(standard_deviation)

        return mean + eps * standard_deviation

    def forward(self, x):
        """Encode ``x``, sample a latent with the reparameterization trick, decode it.

        Raises:
            Exception: if ``x`` is not a tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise Exception("Input must be a tensor".capitalize())

        encoder = self.encoder(x)

        mean = self.mean(encoder)
        log_variance = self.log_variance(encoder)

        # Errors propagate unchanged: catching and printing here would leave
        # ``z`` unbound and surface as an UnboundLocalError instead.
        z = self.reparameterization_trick(mean=mean, log_variance=log_variance)

        return self.decoder(z)

    @staticmethod
    def total_params(model):
        """Return the total number of parameters of a VariationalAutoEncoder.

        Raises:
            Exception: if ``model`` is not a VariationalAutoEncoder.
        """
        if isinstance(model, VariationalAutoEncoder):
            return sum(params.numel() for params in model.parameters())

        else:
            raise Exception("Input must be a VariationalAutoEncoder".capitalize())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Model for Variational Autoencoder".title()
    )
    parser.add_argument(
        "--channels",
        type=int,
        default=config()["VAE"]["channels"],
        help="Number of channels in the input image".title(),
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=config()["VAE"]["image_size"],
        help="Size of the input image".title(),
    )

    # parse_known_args: under Jupyter, sys.argv carries kernel flags such as
    # "-f <connection-file>" that parse_args() would reject with SystemExit.
    args, _ = parser.parse_known_args()

    variational_autoencoder = VariationalAutoEncoder(
        channels=args.channels, image_size=args.image_size
    )

    # Feed an input shaped by the parsed arguments so the round-trip check
    # matches whatever channels / image_size were requested.
    assert variational_autoencoder(
        torch.randn(1, args.channels, args.image_size, args.image_size)
    ).size() == (
        1,
        args.channels,
        args.image_size,
        args.image_size,
    )

    # The expected count was computed for channels=3, image_size=256; the
    # channel widths scale with image_size, so other settings give other counts.
    if (args.channels, args.image_size) == (3, 256):
        assert VariationalAutoEncoder.total_params(variational_autoencoder) == 348547

    # torchsummary's summary() prints the table itself. The model was never
    # moved off the CPU, so run the summary there explicitly.
    summary(
        model=variational_autoencoder,
        input_size=(args.channels, args.image_size, args.image_size),
        device="cpu",
    )

    draw_graph(
        model=variational_autoencoder,
        input_data=torch.randn(1, args.channels, args.image_size, args.image_size),
    ).visual_graph.render(
        filename=os.path.join(config()["path"]["FILES_PATH"], "VAE"), format="png"
    )

    print(
        "Model Architecture saved as VAE.png in the path {}".format(
            config()["path"]["FILES_PATH"]
        )
    )

### MSELoss

In [None]:
import sys
import torch
import argparse
import torch.nn as nn


class MSELoss(nn.Module):
    """Thin wrapper around ``nn.MSELoss`` that type-checks its inputs.

    ``reduction`` is forwarded to ``nn.MSELoss`` on every call. The most
    recently built criterion is kept on ``self.loss``.
    """

    def __init__(self, reduction="mean"):
        super().__init__()

        self.reduction = reduction

    def forward(self, pred, actual):
        """Return the mean-squared error between ``pred`` and ``actual``.

        Raises:
            Exception: unless both arguments are tensors.
        """
        if not (isinstance(pred, torch.Tensor) and isinstance(actual, torch.Tensor)):
            raise Exception("Both inputs must be of type torch.Tensor".capitalize())

        self.loss = nn.MSELoss(reduction=self.reduction)
        return self.loss(pred, actual)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MSE Loss".capitalize())
    parser.add_argument(
        "--reduction",
        type=str,
        default="mean",
        choices=["none", "mean", "sum"],
        help="reduction method".capitalize(),
    )

    # parse_known_args: under Jupyter, sys.argv carries kernel flags such as
    # "-f <connection-file>" that parse_args() would reject with SystemExit.
    args, _ = parser.parse_known_args()

    loss = MSELoss(reduction=args.reduction)

    predicted = torch.tensor([1.0, 0.0, 1.0, 0.0])
    actual = torch.tensor([1.0, 0.0, 1.0, 0.0])

    # Identical inputs give zero loss for every reduction. With "none" the
    # result is per-element, so compare all elements rather than one scalar.
    assert torch.all(loss(predicted, actual) == 0.0)

### KL Divergence (`KLDiversance`)

In [None]:
import sys
import torch
import argparse
import torch.nn as nn


class KLDiversance(nn.Module):
    """KL-divergence term of the VAE loss.

    Computes ``-0.5 * sum(1 + log_variance - mean**2 - exp(log_variance))``,
    summed over every element (not averaged over the batch).
    """

    def __init__(self, name="KLDiversance"):
        super().__init__()

        self.name = name

    def forward(self, mean, log_variance):
        """Return the summed KL term for the given mean / log-variance tensors.

        Raises:
            Exception: unless both arguments are tensors.
        """
        both_tensors = isinstance(mean, torch.Tensor) and isinstance(
            log_variance, torch.Tensor
        )
        if not both_tensors:
            raise Exception("mean and log_variance must be torch.Tensor".capitalize())

        per_element = 1 + log_variance - mean.pow(2) - log_variance.exp()
        return -0.5 * per_element.sum()


if __name__ == "__main__":
    # Smoke check: KLDiversance on random mean / log-variance maps of shape
    # (1, 64, 64, 64) must return a plain torch.Tensor.
    # NOTE(review): ``parser`` is never parsed, so ``--kl`` currently has no effect.
    parser = argparse.ArgumentParser(description="KLDiversance".title())
    parser.add_argument("--kl", action="store_true", help="KLDiversance".capitalize())

    mean = torch.randn(1, 64, 64, 64)
    log_variance = torch.randn(1, 64, 64, 64)

    loss = KLDiversance()

    assert type(loss(mean, log_variance)) == torch.Tensor

### Helper

In [None]:
import os
import sys
import torch
import traceback
import torch.optim as optim

def load_dataloader():
    """Load the pickled train/valid dataloaders from PROCESSED_DATA_PATH.

    Returns:
        dict with keys ``train_dataloader`` and ``valid_dataloader``.

    Raises:
        CustomException: if PROCESSED_DATA_PATH does not exist.
    """
    processed_datapath = config()["path"]["PROCESSED_DATA_PATH"]

    if not os.path.exists(processed_datapath):
        raise CustomException("Processed data not found".capitalize())

    train_dataloader, valid_dataloader = (
        load(filename=os.path.join(processed_datapath, pickle_name))
        for pickle_name in ("train_dataloader.pkl", "valid_dataloader.pkl")
    )

    return {
        "train_dataloader": train_dataloader,
        "valid_dataloader": valid_dataloader,
    }


def helpers(**kwargs):
    """Build the model, optimizer, losses and dataloaders used for training.

    Keyword Args (all required):
        adam (bool): use Adam with ``lr`` and ``betas=(beta1, beta2)``.
        SGD (bool): use SGD with ``lr`` and ``momentum``. Takes precedence
            when both flags are set.
        lr, beta1, beta2, momentum (float): optimizer hyper-parameters.

    Returns:
        dict with ``train_dataloader``, ``valid_dataloader``, ``model``,
        ``optimizer``, ``criterion`` and ``kl_diversance_loss``.

    Raises:
        CustomException: if neither ``adam`` nor ``SGD`` is set.
        Errors from ``load_dataloader`` are printed with a traceback and
        then re-raised.
    """
    adam = kwargs["adam"]
    SGD = kwargs["SGD"]
    lr = kwargs["lr"]
    beta1 = kwargs["beta1"]
    beta2 = kwargs["beta2"]
    momentum = kwargs["momentum"]

    settings = config()  # read the YAML once instead of once per key
    channels = settings["VAE"]["channels"]
    image_size = settings["VAE"]["image_size"]

    # The model's channel widths depend on image_size, so it must match the data.
    assert image_size == settings["dataloader"]["image_size"], (
        "VAE image_size must equal dataloader image_size"
    )

    model = VariationalAutoEncoder(channels=channels, image_size=image_size)

    if SGD:
        optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)
    elif adam:
        optimizer = optim.Adam(params=model.parameters(), lr=lr, betas=(beta1, beta2))
    else:
        # Without a selected optimizer the returned dict would be unusable.
        raise CustomException("Either adam or SGD must be True".capitalize())

    criterion = MSELoss(reduction="mean")
    kl_diversance_loss = KLDiversance()

    try:
        dataloader = load_dataloader()

    except CustomException as e:
        print("An eeror is occured: ", e)
        traceback.print_exc()
        raise  # continuing would hit an unbound ``dataloader``

    except Exception as e:
        print("An error is occured: ", e)
        traceback.print_exc()
        raise  # continuing would hit an unbound ``dataloader``

    return {
        "train_dataloader": dataloader["train_dataloader"],
        "valid_dataloader": dataloader["valid_dataloader"],
        "model": model,
        "optimizer": optimizer,
        "criterion": criterion,
        "kl_diversance_loss": kl_diversance_loss,
    }


if __name__ == "__main__":
    # Smoke check: build the training objects with Adam and verify that each
    # entry of the returned dict has exactly the expected class.
    init = helpers(adam=True, SGD=False, lr=0.001, beta1=0.9, beta2=0.999, momentum=0.9)

    assert init["train_dataloader"].__class__ == torch.utils.data.dataloader.DataLoader
    assert init["valid_dataloader"].__class__ == torch.utils.data.dataloader.DataLoader

    assert init["model"].__class__ == VariationalAutoEncoder
    assert init["optimizer"].__class__ == torch.optim.Adam

    assert init["criterion"].__class__ == MSELoss
    assert init["kl_diversance_loss"].__class__ == KLDiversance

### Unittest

In [None]:
import os
import sys
import torch
import unittest
import torch.nn as nn

class UnitTest(unittest.TestCase):
    """Smoke tests for the pickled dataloaders, the model blocks and the losses."""

    def setUp(self):
        # Load both pickled dataloaders from PROCESSED_DATA_PATH (train first).
        self.train_dataloader, self.valid_dataloader = (
            load(
                filename=os.path.join(
                    config()["path"]["PROCESSED_DATA_PATH"], pickle_name
                )
            )
            for pickle_name in ("train_dataloader.pkl", "valid_dataloader.pkl")
        )

    def test_dataloader(self):
        expected_class = torch.utils.data.dataloader.DataLoader
        for dataloader in (self.train_dataloader, self.valid_dataloader):
            self.assertEqual(dataloader.__class__, expected_class)

    def test_quantity_train_dataloader(self):
        total = sum(batch.size(0) for batch, _ in self.train_dataloader)
        self.assertEqual(total, 12)

    def test_quantity_valid_dataloader(self):
        total = sum(batch.size(0) for batch, _ in self.valid_dataloader)
        self.assertEqual(total, 6)

    def test_encoder(self):
        # Two encoder blocks: 3 -> 128 -> 64 channels, 256 -> 128 -> 64 pixels.
        model = nn.Sequential(
            EncoderBlock(in_channels=3, out_channels=128),
            EncoderBlock(in_channels=128, out_channels=64),
        )

        output = model(torch.randn(1, 3, 256, 256))
        self.assertEqual(output.size(), torch.Size([1, 64, 64, 64]))

    def test_decoder(self):
        # Decoder block then output projection: 64x64 -> 128x128 -> 256x256.
        model = nn.Sequential(
            DecoderBlock(in_channels=64, out_channels=128),
            nn.ConvTranspose2d(
                in_channels=128,
                out_channels=3,
                kernel_size=4,
                stride=2,
                padding=1,
            ),
        )

        output = model(torch.randn(1, 64, 64, 64))
        self.assertEqual(output.size(), torch.Size([1, 3, 256, 256]))

    def test_VAE(self):
        self.model = VariationalAutoEncoder()

        reconstruction = self.model(torch.randn(1, 3, 256, 256))
        self.assertEqual(reconstruction.size(), torch.Size([1, 3, 256, 256]))
        self.assertIsInstance(self.model, VariationalAutoEncoder)

    def test_loss(self):
        self.init = helpers(
            adam=True, SGD=False, lr=0.001, beta1=0.9, beta2=0.999, momentum=0.9
        )

        self.assertEqual(self.init["criterion"].__class__, MSELoss)
        self.assertEqual(self.init["kl_diversance_loss"].__class__, KLDiversance)


if __name__ == "__main__":
    # Under Jupyter, sys.argv holds kernel flags ("-f <connection-file>") that
    # unittest's argument parser rejects, and unittest.main() ends with
    # sys.exit(). Inside a kernel, run with a clean argv and exit=False. As a
    # plain script this is the same call as unittest.main().
    in_notebook = "ipykernel" in sys.modules
    unittest.main(argv=sys.argv[:1] if in_notebook else None, exit=not in_notebook)

### Trainer

In [None]:
import os
import sys
import torch
import mlflow
import dagshub
import argparse
import warnings
import traceback
import numpy as np
from tqdm import tqdm
from dotenv import load_dotenv
from dagshub import dagshub_logger
from torch.optim.lr_scheduler import StepLR
from torchvision.utils import save_image

# Load variables from a .env file into the process environment (python-dotenv);
# presumably the MLflow / DagsHub credentials used by the trainer — confirm.
load_dotenv()

# NOTE(review): this silences *all* warnings for the whole process, including
# PyTorch deprecation and numerical warnings; consider narrowing the filter.
warnings.filterwarnings("ignore")


class Trainer:
    """Train the VariationalAutoEncoder supplied by ``helpers`` and track the run.

    Dataloaders, model, optimizer, reconstruction criterion and KL-divergence
    term all come from ``helpers``. Each epoch does the following:

    * one optimisation pass over the training loader;
    * one gradient-free pass over the validation loader;
    * an optional ``StepLR`` step;
    * progress output;
    * periodic sample reconstructions and model snapshots;
    * best-model checkpointing by validation loss;
    * MLflow logging through DagsHub, when ``MLFlow`` is true.

    Args:
        epochs: Number of training epochs.
        lr, beta1, beta2, momentum: Optimizer settings forwarded to ``helpers``.
        weight_decay: Coefficient of the optional L1/L2 parameter penalty.
        step_size, gamma: ``StepLR`` settings, used when ``lr_scheduler`` is true.
        adam, SGD: Optimizer selection forwarded to ``helpers``.
        device: ``"cuda"``, ``"mps"`` or ``"cpu"``; ``device_init`` falls back
            to CPU when the requested backend is unavailable.
        verbose: Print per-epoch losses instead of a short completion line.
        lr_scheduler: Enable the ``StepLR`` scheduler.
        weight_init: Apply the module-level ``weight_init`` function to the model.
        l1_regularization, l2_regularization: Add the corresponding penalty to
            the training loss.
        MLFlow: Initialise DagsHub/MLflow and log params, metrics and the model.
    """

    def __init__(
        self,
        epochs=100,
        lr=0.001,
        beta1=0.9,
        beta2=0.999,
        momentum=0.9,
        weight_decay=0.0001,
        step_size=10,
        gamma=0.85,
        adam=True,
        SGD=False,
        device="cuda",
        verbose=True,
        lr_scheduler=False,
        weight_init=False,
        l1_regularization=False,
        l2_regularization=False,
        MLFlow=True,
    ):
        self.epochs = epochs
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.step_size = step_size
        self.gamma = gamma
        self.adam = adam
        self.SGD = SGD
        self.device = device
        self.verbose = verbose
        self.lr_scheduler = lr_scheduler
        self.weight_init = weight_init
        self.l1_regularization = l1_regularization
        self.l2_regularization = l2_regularization
        self.MLFlow = MLFlow

        self.init = helpers(
            adam=self.adam,
            SGD=self.SGD,
            lr=self.lr,
            beta1=self.beta1,
            beta2=self.beta2,
            momentum=self.momentum,
        )

        try:
            self.device = device_init(device=self.device)
        except Exception as e:
            # CustomException takes a single message; passing extra positional
            # arguments would itself raise TypeError and hide the real error.
            raise CustomException(str(e)) from e

        self.train_dataloader = self.init["train_dataloader"]
        self.valid_dataloader = self.init["valid_dataloader"]

        self.model = self.init["model"].to(self.device)
        self.optimizer = self.init["optimizer"]

        self.criterion = self.init["criterion"]
        self.kl_divergence_loss = self.init["kl_diversance_loss"]

        self._check_components()

        if self.weight_init:
            self._apply_weight_init()

        if self.lr_scheduler:
            self.scheduler = StepLR(
                optimizer=self.optimizer, step_size=self.step_size, gamma=self.gamma
            )

        # Best (lowest) mean validation loss seen so far; drives best_model.pth.
        self.loss = float("inf")
        self.history = {"train_loss": [], "valid_loss": []}

        if self.MLFlow:
            # MLflow picks up MLFLOW_TRACKING_* credentials from the environment
            # (populated by the module-level load_dotenv() call).
            dagshub.init(
                repo_owner="atikul-islam-sajib", repo_name="VAE-Pytorch", mlflow=True
            )
            mlflow.set_experiment(experiment_name="Variational Auto Encoder".title())

    def _check_components(self):
        """Sanity-check the objects returned by ``helpers``."""
        assert isinstance(self.train_dataloader, torch.utils.data.DataLoader)
        assert isinstance(self.valid_dataloader, torch.utils.data.DataLoader)

        assert isinstance(self.model, VariationalAutoEncoder)
        # Accept either optimizer; hard-coding Adam made SGD=True unusable.
        assert isinstance(self.optimizer, (torch.optim.Adam, torch.optim.SGD))

        assert isinstance(self.criterion, MSELoss)
        # NOTE(review): the unit tests expect a class named KLDiversance here;
        # confirm which name is actually defined.
        assert isinstance(self.kl_divergence_loss, KLDivergence)

    def _apply_weight_init(self):
        """Apply the module-level ``weight_init`` function to every submodule."""
        # Inside __init__ the name ``weight_init`` is the boolean parameter, so
        # model.apply(weight_init) there would try to call a bool. In this
        # method scope the name resolves to the module-level function.
        self.model.apply(weight_init)

    @staticmethod
    def _snapshot_interval(epochs):
        """Return the epoch interval for periodic snapshots (~20 per run, at least every epoch).

        ``epochs // 20`` alone is 0 for runs shorter than 20 epochs, which made
        the modulo checks raise ZeroDivisionError.
        """
        return max(1, epochs // 20)

    @staticmethod
    def _best_effort(action, **kwargs):
        """Call ``action(**kwargs)``; report any failure without stopping training."""
        try:
            action(**kwargs)
        except Exception as e:
            print("An error occured: {}".format(e))
            traceback.print_exc()

    def _hyperparameters(self):
        """Return the run configuration logged to MLflow."""
        return {
            "epochs": self.epochs,
            "lr": self.lr,
            "beta1": self.beta1,
            "beta2": self.beta2,
            "momentum": self.momentum,
            "weight_decay": self.weight_decay,
            "step_size": self.step_size,
            "gamma": self.gamma,
            "adam": self.adam,
            "SGD": self.SGD,
            "device": self.device,
            "lr_scheduler": self.lr_scheduler,
            "weight_init": self.weight_init,
            "l1_regularization": self.l1_regularization,
            "l2_regularization": self.l2_regularization,
            "verbose": self.verbose,
        }

    def l1_regularization_loss(self, model):
        """Return ``weight_decay`` times the sum of L1 norms of ``model``'s parameters.

        Raises:
            CustomException: If ``model`` is not a VariationalAutoEncoder.
        """
        if not isinstance(model, VariationalAutoEncoder):
            raise CustomException("Model is not an instance of VariationalAutoEncoder")

        return self.weight_decay * sum(
            torch.norm(params, 1) for params in model.parameters()
        )

    def l2_regularization_loss(self, model):
        """Return ``weight_decay`` times the sum of (non-squared) L2 norms of ``model``'s parameters.

        Raises:
            CustomException: If ``model`` is not a VariationalAutoEncoder.
        """
        if not isinstance(model, VariationalAutoEncoder):
            raise CustomException("Model is not an instance of VariationalAutoEncoder")

        return self.weight_decay * sum(
            torch.norm(params, 2) for params in model.parameters()
        )

    def update_model_loss(self, **kwargs):
        """Run one optimisation step on a batch and return the total loss as a float.

        Keyword Args:
            X: Input batch already on ``self.device``.
            y: Reconstruction target batch already on ``self.device``.

        The loss is criterion(predicted, y) plus the KL term, plus any enabled
        L1/L2 penalty.
        """
        self.optimizer.zero_grad()

        X = kwargs["X"]
        y = kwargs["y"]

        predicted, mean, log_variance = self.model(X)

        criterion_loss = self.criterion(predicted, y)
        kl_divergence_loss = self.kl_divergence_loss(mean, log_variance)

        self.total_loss = criterion_loss + kl_divergence_loss

        if self.l1_regularization:
            self.total_loss += self.l1_regularization_loss(self.model)

        if self.l2_regularization:
            self.total_loss += self.l2_regularization_loss(self.model)

        self.total_loss.backward()
        self.optimizer.step()

        return self.total_loss.item()

    def show_progress(self, **kwargs):
        """Print the epoch's losses when ``verbose``, otherwise a completion line.

        Keyword Args:
            epoch: 1-based epoch number.
            train_loss, valid_loss: Mean losses for the epoch (used when verbose).
        """
        if self.verbose:
            print(
                "Epochs:[{}/{}] - train_loss: [{:.4f}] - valid_loss:{:.4f}".format(
                    kwargs["epoch"],
                    self.epochs,
                    kwargs["train_loss"],
                    kwargs["valid_loss"],
                )
            )
        else:
            print(
                "Epochs:[{}/{}] is completed".capitalize().format(
                    kwargs["epoch"], self.epochs
                )
            )

    def save_images(self, **kwargs):
        """On snapshot epochs, save reconstructions of one training batch as a PNG.

        Keyword Args:
            epoch: 1-based epoch number; used in the file name.
        """
        epoch = kwargs["epoch"]

        # Skip the forward pass entirely on epochs where nothing is written.
        if epoch % self._snapshot_interval(self.epochs) != 0:
            return

        X, _ = next(iter(self.train_dataloader))
        X = X.to(self.device)

        with torch.no_grad():
            predicted, _, _ = self.model(X)

        save_image(
            predicted,
            os.path.join(
                config()["path"]["TRAIN_IMAGES_PATH"],
                "train_image{}.png".format(epoch),
            ),
        )

    def saved_checkpoints(self, **kwargs):
        """Save ``best_model.pth`` when validation loss improves, plus periodic snapshots.

        Keyword Args:
            epoch: 1-based epoch number.
            train_loss, valid_loss: Mean losses for the epoch.
        """
        epoch = kwargs["epoch"]
        # Store plain floats rather than NumPy scalars (np.mean returns one),
        # which torch.load(weights_only=True) may refuse to unpickle.
        train_loss = float(kwargs["train_loss"])
        valid_loss = float(kwargs["valid_loss"])

        if self.loss > valid_loss:
            self.loss = valid_loss

            torch.save(
                {
                    "model": self.model.state_dict(),
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                },
                os.path.join(config()["path"]["TEST_MODELS"], "best_model.pth"),
            )

        if epoch % self._snapshot_interval(self.epochs) == 0:
            torch.save(
                self.model.state_dict(),
                os.path.join(
                    config()["path"]["TRAIN_MODELS"], "model{}.pth".format(epoch)
                ),
            )

    def _train_one_epoch(self):
        """Run one optimisation pass over the training loader; return per-batch losses."""
        self.model.train()

        losses = []
        for X, y in self.train_dataloader:
            X = X.to(self.device)
            y = y.to(self.device)
            losses.append(self.update_model_loss(X=X, y=y))
        return losses

    def _validate_one_epoch(self):
        """Evaluate the model on the validation loader; return per-batch losses.

        Runs in eval mode under ``torch.no_grad()`` so no autograd graph is kept
        and normalisation statistics are not updated with validation data.
        """
        self.model.eval()

        losses = []
        with torch.no_grad():
            for X, y in self.valid_dataloader:
                X = X.to(self.device)
                y = y.to(self.device)

                predicted, mean, log_variance = self.model(X)

                total_loss = self.criterion(predicted, y) + self.kl_divergence_loss(
                    mean, log_variance
                )
                losses.append(total_loss.item())
        return losses

    def train(self):
        """Run the training loop, then persist the history (and the model to MLflow).

        Failures while printing progress or saving images/checkpoints are
        reported but do not interrupt training.
        """
        from contextlib import nullcontext

        run_context = (
            mlflow.start_run(
                description="In machine learning, a variational autoencoder is an artificial neural network architecture introduced by Diederik P. Kingma and Max Welling. It belongs to the family of probabilistic graphical models and variational Bayesian methods."
            )
            if self.MLFlow
            else nullcontext()
        )

        with run_context:
            if self.MLFlow:
                # Hyperparameters are fixed for the whole run: log them once
                # rather than re-sending them every epoch.
                mlflow.log_params(self._hyperparameters())

            for epoch in tqdm(range(self.epochs)):
                step = epoch + 1

                self.train_loss = self._train_one_epoch()
                self.valid_loss = self._validate_one_epoch()

                if self.lr_scheduler:
                    self.scheduler.step()

                train_loss = np.mean(self.train_loss)
                valid_loss = np.mean(self.valid_loss)

                self._best_effort(
                    self.show_progress,
                    epoch=step,
                    train_loss=train_loss,
                    valid_loss=valid_loss,
                )
                self._best_effort(self.save_images, epoch=step)
                self._best_effort(
                    self.saved_checkpoints,
                    epoch=step,
                    train_loss=train_loss,
                    valid_loss=valid_loss,
                )

                self.history["train_loss"].append(train_loss)
                self.history["valid_loss"].append(valid_loss)

                if self.MLFlow:
                    mlflow.log_metric(key="train_loss", value=train_loss, step=step)
                    mlflow.log_metric(key="valid_loss", value=valid_loss, step=step)

            if self.MLFlow:
                mlflow.pytorch.log_model(self.model, "model")

            dump(
                value=self.history,
                filename=os.path.join(
                    config()["path"]["TRAIN_HISTORY_PATH"], "history.pkl"
                ),
            )

        paths = config()["path"]
        print("Train image saved in the path {}".format(paths["TRAIN_IMAGES_PATH"]))
        print(
            "Train and best models saved in the path {} and {}".format(
                paths["TRAIN_MODELS"], paths["TEST_MODELS"]
            )
        )
        if self.MLFlow:
            print("To visualize the MLFlow user-interface, run the command: mlflow ui")
        print("To visualize the MLFlow user-interface, run the command: mlflow ui")


def str2bool(value):
    """Parse a command-line boolean for argparse.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, so any non-empty
    string -- including ``"False"`` -- becomes True and flags such as
    ``--adam False`` could never be turned off. This accepts the usual
    spellings (case-insensitive) and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value

    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value, got {!r}".format(value)
    )


if __name__ == "__main__":
    # Read the YAML config once instead of once per argument.
    trainer_config = config()["trainer"]

    parser = argparse.ArgumentParser(description="Train the model for VAE".title())

    # (flag, parser type, help text); each flag's config key is its name
    # without the leading dashes.
    cli_options = (
        ("--epochs", int, "Number of epochs"),
        ("--lr", float, "Learning rate"),
        ("--beta1", float, "Beta1 for Adam optimizer"),
        ("--beta2", float, "Beta2 for Adam optimizer"),
        ("--momentum", float, "Momentum for SGD optimizer"),
        ("--weight_decay", float, "Weight decay for SGD optimizer"),
        ("--step_size", int, "Step size for learning rate scheduler"),
        ("--gamma", float, "Gamma for learning rate scheduler"),
        ("--adam", str2bool, "Use Adam optimizer"),
        ("--SGD", str2bool, "Use SGD optimizer"),
        ("--device", str, "Device to use"),
        ("--verbose", str2bool, "Verbose mode"),
        ("--lr_scheduler", str2bool, "Use learning rate scheduler"),
        ("--weight_init", str2bool, "Use weight initialization"),
        ("--l1_regularization", str2bool, "Use L1 regularization"),
        ("--l2_regularization", str2bool, "Use L2 regularization"),
        ("--MLFlow", str2bool, "Use MLFlow for tracking"),
    )

    for flag, value_type, help_text in cli_options:
        parser.add_argument(
            flag,
            type=value_type,
            default=trainer_config[flag.lstrip("-")],
            help=help_text.capitalize(),
        )

    args = parser.parse_args()

    trainer = Trainer(
        epochs=args.epochs,
        lr=args.lr,
        beta1=args.beta1,
        beta2=args.beta2,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        step_size=args.step_size,
        gamma=args.gamma,
        adam=args.adam,
        SGD=args.SGD,
        device=args.device,
        lr_scheduler=args.lr_scheduler,
        weight_init=args.weight_init,
        l1_regularization=args.l1_regularization,
        l2_regularization=args.l2_regularization,
        verbose=args.verbose,
        MLFlow=args.MLFlow,
    )

    trainer.train()

### Tester


In [None]:
import os
import sys
import torch
import argparse
import matplotlib.pyplot as plt

class Tester:
    """Load a trained VariationalAutoEncoder and plot reconstructions of one validation batch.

    Args:
        model_path: ``"best"`` to load the ``"model"`` entry of
            ``best_model.pth`` in the configured ``TEST_MODELS`` directory;
            otherwise a path to a file holding a ``state_dict``.
        device: ``"cuda"``, ``"mps"`` or ``"cpu"``, resolved via ``device_init``.
    """

    def __init__(self, model_path="best", device="cuda"):
        self.model_path = model_path
        self.device = device_init(device=device)

    def select_the_model(self):
        """Return the state dict to load into the model.

        Tensors are mapped onto ``self.device`` so a checkpoint saved on a GPU
        can be loaded on a machine without one.
        """
        if self.model_path == "best":
            checkpoint_path = os.path.join(
                config()["path"]["TEST_MODELS"], "best_model.pth"
            )
            return torch.load(checkpoint_path, map_location=self.device)["model"]

        return torch.load(self.model_path, map_location=self.device)

    @staticmethod
    def _normalize(image):
        """Min-max scale ``image`` to [0, 1] for display.

        A constant image maps to all zeros instead of 0/0 = NaN.
        """
        shifted = image - image.min()
        span = shifted.max()
        return shifted / span if span > 0 else shifted

    def plot(self):
        """Plot Actual / Predicted / Target triplets for one validation batch.

        Loads ``valid_dataloader.pkl`` from ``PROCESSED_DATA_PATH``, saves the
        figure as ``test_result.png`` in ``VALID_IMAGES_PATH`` and shows it.
        Requires ``self.model`` to be set (``test`` does this).
        """
        valid_dataloader = load(
            filename=os.path.join(
                config()["path"]["PROCESSED_DATA_PATH"], "valid_dataloader.pkl"
            )
        )

        X, y = next(iter(valid_dataloader))
        X = X.to(self.device)
        y = y.to(self.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            predicted, _, _ = self.model(X)

        batch_size = X.size(0)
        number_of_rows = (batch_size + 1) // 2
        # Two samples per row, three panels (actual, predicted, target) each.
        number_of_columns = 2 * 3

        plt.figure(figsize=(15, 15))

        for index in range(batch_size):
            panels = (
                ("Actual", X[index]),
                ("Predicted", predicted[index]),
                ("Target", y[index]),
            )
            for offset, (title, tensor) in enumerate(panels, start=1):
                # CHW tensor -> HWC array for imshow.
                image = self._normalize(
                    tensor.permute(1, 2, 0).cpu().detach().numpy()
                )

                plt.subplot(number_of_rows, number_of_columns, 3 * index + offset)
                plt.imshow(image)
                plt.axis("off")
                plt.title(title)

        plt.tight_layout()

        output_dir = config()["path"]["VALID_IMAGES_PATH"]
        plt.savefig(os.path.join(output_dir, "test_result.png"))
        plt.show()

        print("The test result is saved in the path {}".format(output_dir))

    def test(self):
        """Build the model on ``self.device``, load the selected weights, and plot results."""
        self.model = VariationalAutoEncoder().to(self.device)
        self.model.load_state_dict(self.select_the_model())

        self.model.eval()
        self.plot()


if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Tester for Variational Auto Encoder".title()
    )

    # Each option's default comes from the "tester" section of the config,
    # keyed by the flag name without its leading "--".
    for flag, help_text in (
        ("--model", "Path to the model to be tested"),
        ("--device", "Device to be used"),
    ):
        cli.add_argument(
            flag,
            type=str,
            default=config()["tester"][flag[2:]],
            help=help_text.capitalize(),
        )

    options = cli.parse_args()

    Tester(model_path=options.model, device=options.device).test()