# Image Inpainting problem

In [9]:
import os
import sys
import torch
import torch.nn as nn
import torchvision.transforms as T
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models
sys.path.append(os.path.abspath(".."))
from ml_utils import parse_config, vizualization, decode_dataset

In [1]:
def _typed(json_type: str) -> dict:
    """Return a leaf JSON Schema that only constrains the value's type."""
    return {"type": json_type}


def _flag_path_schema() -> dict:
    """Return the schema of an option group shaped like {"flag": bool, "path": str}."""
    return {
        "type": "object",
        "properties": {"flag": _typed("boolean"), "path": _typed("string")},
        "required": ["flag", "path"],
    }


def _stage_schema(*, with_load: bool) -> dict:
    """Return the schema of one pipeline stage section ("pretrain" or "learn").

    Both stages share the same "data", "hpr_prm" and "opt" layout; only the
    "learn" stage additionally requires an "opt.load" group.

    Args:
        with_load: Whether the "opt" group must contain a "load" sub-group.

    Returns:
        A freshly built dict, so the two stage sections never share
        mutable sub-dicts.
    """
    opt_properties = {
        "displ": {
            "type": "object",
            "properties": {"flag": _typed("boolean"), "step": _typed("integer")},
            "required": ["flag", "step"],
        },
        "store": _flag_path_schema(),
    }
    if with_load:
        opt_properties["load"] = _flag_path_schema()

    return {
        "type": "object",
        "properties": {
            "flag": _typed("boolean"),
            "data": {
                "type": "object",
                "properties": {
                    "dataset": _typed("string"),
                    # JSON Schema has no "int" type; integers are "integer".
                    "batch_size": _typed("integer"),
                    "test_path": _typed("string"),
                    "train_path": _typed("string"),
                },
                "required": ["dataset", "batch_size", "test_path", "train_path"],
            },
            # Without "type"/"properties" the per-field constraints below
            # would be ignored by a JSON Schema validator.
            "hpr_prm": {
                "type": "object",
                "properties": {
                    "lr": _typed("number"),
                    "epochs": _typed("integer"),
                    # Key spelling ("chanels") is kept: config files use it.
                    "chanels": _typed("integer"),
                },
                "required": ["lr", "epochs", "chanels"],
            },
            "opt": {
                "type": "object",
                "properties": opt_properties,
                "required": list(opt_properties),
            },
        },
        "required": ["flag", "data", "hpr_prm", "opt"],
    }


# Per-stage configuration schemas, keyed by the config section name.
# NOTE(review): the top level is not itself wrapped in
# {"type": "object", "properties": {...}}; if ml_utils.parse_config hands
# VALID_JSON directly to a JSON Schema validator, the per-section schemas are
# not applied -- confirm how parse_config consumes this dict.
VALID_JSON = {
    "pretrain": _stage_schema(with_load=False),
    "learn": _stage_schema(with_load=True),
    # Must name sections defined above (previously listed an undefined "train").
    "required": ["pretrain", "learn"],
}

In [None]:
class DoubleConv(nn.Module):
    """Two back-to-back (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use padding=1, so the spatial size is preserved; only
    the channel count changes, from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        stages = []
        # First stage maps in->out, second keeps out->out.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name is kept verbatim: it determines the state_dict keys.
        self.doubl_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.doubl_conv(x)

In [None]:
class UNetAutoencoder(nn.Module):
    """Convolutional encoder-decoder whose output has the input's shape.

    Encoder: four DoubleConv stages (64, 128, 256, 512 channels), each
    followed by 2x2 max pooling, then a DoubleConv bottleneck with
    ``latent_dim`` channels at 1/16 of the input resolution.
    Decoder: four stride-2 transposed convolutions, each followed by a
    DoubleConv, then a 1x1 convolution back to ``in_channels``.

    NOTE(review): there are no skip connections between encoder and decoder
    stages, so despite the name this is a plain convolutional autoencoder.

    Args:
        in_channels: Channels of the input image (and of the reconstruction).
        latent_dim: Channels of the bottleneck feature map.
    """

    # Total spatial down/up-sampling factor: four 2x2 pools / four x2 upsamples.
    _SCALE = 16

    def __init__(self, in_channels, latent_dim):
        super().__init__()
        # Every encoder stage must be a module (bare tuples are rejected by
        # nn.Sequential) and must be followed by a pool, so that the four
        # pools here exactly cancel the four stride-2 transposed convs below.
        self.autoencoder = nn.Sequential(
            # Encoder (downsampling)
            DoubleConv(in_channels, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(64, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(128, 256),
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(256, 512),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Bottleneck (latent representation)
            DoubleConv(512, latent_dim),
            # Decoder (upsampling)
            nn.ConvTranspose2d(latent_dim, 512, kernel_size=2, stride=2),
            DoubleConv(512, 512),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            DoubleConv(256, 256),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            DoubleConv(128, 128),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            DoubleConv(64, 64),
            nn.Conv2d(64, in_channels, kernel_size=1),
        )

    def forward(self, x):
        """Encode and decode ``x``.

        Args:
            x: Image batch, expected (N, in_channels, H, W) since the layers
                include BatchNorm2d.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If H or W is not divisible by 16; max pooling floors
                odd sizes, so the reconstruction would not match the input.
        """
        height, width = x.shape[-2:]
        if height % self._SCALE or width % self._SCALE:
            raise ValueError(
                f"UNetAutoencoder expects height and width divisible by "
                f"{self._SCALE}, got {height}x{width}"
            )
        return self.autoencoder(x)


In [None]:
def mask():
    """Placeholder, presumably for generating the inpainting mask.

    TODO: not implemented yet -- the body is empty, so calling it currently
    returns ``None``.
    """
    return None

In [None]:
def get_dataset():
    """Placeholder, presumably for building the training/test datasets.

    TODO: not implemented yet -- the body is empty, so calling it currently
    returns ``None``.
    """
    return None

In [None]:
def pretrain(config: dict):
    """Placeholder for the pretraining stage.

    Args:
        config: Stage configuration; ``learn`` passes ``config['pretrain']``
            here. Currently unused.

    TODO: not implemented yet -- calling it currently returns ``None``.
    """
    return None

In [None]:
def train():
    """Placeholder for the main training stage.

    TODO: not implemented yet -- calling it currently returns ``None``.
    """
    return None


In [None]:
def test():
    """Placeholder for the evaluation stage.

    TODO: not implemented yet -- calling it currently returns ``None``.
    """
    return None

In [None]:
def learn(config: dict):
    """Run the pipeline: optional pretraining, then training and testing.

    Args:
        config: Parsed configuration. Only ``config['pretrain']`` is read
            here; pretraining runs when its ``'flag'`` entry is truthy.

    NOTE(review): ``train()`` and ``test()`` receive no configuration, so the
    ``learn`` section of the config is not consumed yet.
    """
    pretrain_section = config['pretrain']
    if pretrain_section['flag']:
        pretrain(pretrain_section)
    train()
    test()