In [1]:
import torch
from torch import nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from fastai.vision.all import URLs
from batchflow import Pipeline, B, C, D, F, V, W
from batchflow.models.torch import EncoderDecoder
from batchflow.models.torch.layers import ConvBlock
from batchflow.models.torch.layers.modules import ASPP, PyramidPooling
from batchflow.opensets import Imagenette160

from train_module import training_functions

# Relative directory path; presumably where graph artifacts are stored
# (not used in the visible cells) — TODO confirm against later cells.
GRAPH_PATH = "./data/graphs/"
# Image shape in channels-first order: (channels, height, width).
# Matches the (None, 3, 160, 160) model input shown in the model summary below.
IMAGE_SHAPE = (3, 160, 160)

In [2]:
# Expose only physical GPU 3 to this kernel. Inside the process the single
# visible GPU is renumbered, so it is addressed as ordinal 0 ("cuda:0").
# NOTE: keep this comment on its own line — text after the magic would be
# passed to %env as part of its argument.
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_VISIBLE_DEVICES=3


In [3]:
# Instantiate the Imagenette160 dataset from batchflow.opensets.
# `bar='n'` presumably selects the notebook-style progress bar (a progress
# bar is printed below) — TODO confirm against the batchflow documentation.
dataset = Imagenette160(bar='n')

 50%|█████     | 1/2 [00:05<00:05,  5.25s/it]


In [4]:
# Channel counts of the input images and of the reconstructed output.
in_channels = 3
out_channels = 3

# CUDA_VISIBLE_DEVICES=3 (set in the cell above) hides every GPU except
# physical GPU 3 and renumbers it, so inside this process that GPU is
# "cuda:0". Asking for "cuda:3" would fail with an invalid device ordinal.
device = "cuda:0"

# Training hyper-parameters.
batch_size = 64
epoch_num = 5

In [5]:
def _make_encoder_decoder_config(num_stages, head_filters):
    """Assemble the configuration dict passed to `EncoderDecoder`.

    Parameters
    ----------
    num_stages : int
        Number of encoder stages and of decoder stages.
    head_filters : int
        Number of filters produced by the convolution in the head.

    Returns
    -------
    dict
        Flat dict whose keys are '/'-separated config paths, inserted in a
        fixed order: inputs, initial block, encoder, embedding, decoder,
        head, loss, optimizer.
    """
    # One entry per convolution-like letter in the downsample layout
    # (three 'w' and three 'c' letters -> six values per stage).
    stage_strides = [1, 1, 1, 1, 2, 1]
    stage_kernels = [3, 1] * 3

    # Side branch of the 'R' ... '|' skip; in the model summary it maps
    # (None, 64, 80, 80) -> (None, 64, 40, 40).
    branch = dict(layout='cn', kernel_size=1, filters='same', strides=2)

    settings = {}
    # 'device': device,   (left disabled)

    # Optional entry: the original author notes it can be commented out.
    settings['inputs/images/shape'] = (3, 160, 160)
    # Optional entry: the original author notes it can be commented out.
    settings['initial_block/inputs'] = 'images'
    settings['initial_block'] = dict(
        layout='cna cna',
        strides=[2, 1],
        filters=[32, 64],
    )

    settings['body/encoder'] = dict(
        num_stages=num_stages,
        order=['downsampling', 'skip'],
    )
    settings['body/encoder/downsample'] = dict(
        layout='R' + 'wnacna' * 3 + '|',
        filters='same',
        # The same per-stage list is repeated for every stage.
        strides=[stage_strides] * num_stages,
        kernel_size=[stage_kernels] * num_stages,
        dilation_rate=2,
        branch=branch,
    )

    # ASPP is the imported alternative the author had in mind here.
    settings['body/embedding'] = dict(base=PyramidPooling)

    settings['body/decoder'] = dict(
        num_stages=num_stages,
        order=['block', 'combine', 'upsampling'],
    )
    settings['body/decoder/blocks'] = dict(layout='cna', kernel_size=1, filters=256)
    settings['body/decoder/upsample'] = dict(layout='b', scale_factor=2)
    settings['body/decoder/combine'] = dict(op='concat')

    settings['head'] = dict(
        layout='cna b',
        scale_factor=2,
        filters=head_filters,
        activation='sigmoid',
    )

    settings['loss'] = 'mse'
    settings['optimizer'] = 'Adam'
    return settings


downsample_depth = 2
in_channels = 3
config = _make_encoder_decoder_config(downsample_depth, in_channels)

In [6]:
# Build the batchflow EncoderDecoder model from the config dict defined above.
model = EncoderDecoder(config)

In [7]:
# Show the condensed model summary: for each block, its layout string and the
# input -> output shape of every layer letter.
model.short_repr()

Sequential(
  (initial_block): ConvBlock
  layout=cnacna
    Layer 0,  letter "c": (None, 3, 160, 160) -> (None, 32, 80, 80)
    Layer 1,  letter "n": (None, 32, 80, 80) -> (None, 32, 80, 80)
    Layer 2,  letter "a": (None, 32, 80, 80) -> (None, 32, 80, 80)
    Layer 3,  letter "c": (None, 32, 80, 80) -> (None, 64, 80, 80)
    Layer 4,  letter "n": (None, 64, 80, 80) -> (None, 64, 80, 80)
    Layer 5,  letter "a": (None, 64, 80, 80) -> (None, 64, 80, 80)
    
  (body): Sequential(
    (encoder): EncoderModule(
      (downsample-0): ConvBlock
      layout=Rwnacnawnacnawnacna|
        Layer 0,    skip "R": (None, 64, 80, 80) -> (None, 64, 40, 40)
        Layer 1,  letter "w": (None, 64, 80, 80) -> (None, 64, 80, 80)
        Layer 2,  letter "n": (None, 64, 80, 80) -> (None, 64, 80, 80)
        Layer 3,  letter "a": (None, 64, 80, 80) -> (None, 64, 80, 80)
        Layer 4,  letter "c": (None, 64, 80, 80) -> (None, 64, 80, 80)
        Layer 5,  letter "n": (None, 64, 80, 80) -> (None, 64,

In [8]:
# Display the wrapped module tree (the `model` attribute), including the
# concrete torch layers and their parameters (e.g. Conv2d, BatchNorm2d, ReLU).
model.model

Sequential(
  (initial_block): ConvBlock(
    (0): BaseConvBlock(
      layout=cnacna
      
      (Layer 0,  letter "c": (None, 3, 160, 160) -> (None, 32, 80, 80)): Conv(
        (layer): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      )
      (Layer 1,  letter "n": (None, 32, 80, 80) -> (None, 32, 80, 80)): BatchNorm(
        (layer): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Layer 2,  letter "a": (None, 32, 80, 80) -> (None, 32, 80, 80)): Activation(
        (activation): ReLU(inplace=True)
      )
      (Layer 3,  letter "c": (None, 32, 80, 80) -> (None, 64, 80, 80)): Conv(
        (layer): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
      )
      (Layer 4,  letter "n": (None, 64, 80, 80) -> (None, 64, 80, 80)): BatchNorm(
        (layer): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Layer 5,  letter "a": (None, 64, 80, 80) -> (None, 64, 80, 80