In [14]:
import os

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torch.autograd import Variable, gradcheck


In [2]:
import import_ipynb
import separable_convolution
import dataset_preparation

importing Jupyter notebook from separable_convolution.ipynb
importing Jupyter notebook from dataset_preparation.ipynb


In [3]:
# !pip install import-ipynb

In [4]:
# Length of each 1-D kernel predicted by the network's kernel heads
# (also determines the replication padding: OUTPUT_1D_KERNEL_SIZE // 2).
OUTPUT_1D_KERNEL_SIZE: int = 51
# Training schedule.
EPOCHS: int = 10
BATCH_SIZE: int = 100

In [5]:
class Net(nn.Module):
    """Encoder/decoder that predicts separable kernels and applies them to two frames.

    The input stacks two RGB frames on the channel axis (channels 0-2 and 3-5,
    see ``forward``). A five-level encoder/decoder with additive skip
    connections feeds four kernel heads, each producing ``OUTPUT_1D_KERNEL_SIZE``
    channels: a vertical/horizontal kernel pair per frame. The output is the sum
    of each replication-padded frame filtered by
    ``separable_convolution.SeparableConvolutionSlow`` with its kernel pair.

    Because the encoder average-pools by 2 five times and the decoder adds the
    upsampled features back onto the encoder features, input height and width
    must be divisible by 32.
    """

    def __init__(self, init_weights=True):
        """Build all sub-networks.

        Args:
            init_weights: when True, re-initialise every ``nn.Conv2d`` weight via
                ``_weight_init``. ``from_file`` passes False because weights are
                loaded right afterwards.
        """
        super(Net, self).__init__()

        conv_kernel = (3, 3)
        conv_stride = (1, 1)
        conv_padding = 1
        sep_kernel = OUTPUT_1D_KERNEL_SIZE

        self.pool = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.upsamp = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.relu = nn.ReLU()

        # Encoder.
        self.conv32 = self._conv_module(6, 32, conv_kernel, conv_stride, conv_padding, self.relu)
        self.conv64 = self._conv_module(32, 64, conv_kernel, conv_stride, conv_padding, self.relu)
        self.conv128 = self._conv_module(64, 128, conv_kernel, conv_stride, conv_padding, self.relu)
        self.conv256 = self._conv_module(128, 256, conv_kernel, conv_stride, conv_padding, self.relu)
        self.conv512 = self._conv_module(256, 512, conv_kernel, conv_stride, conv_padding, self.relu)
        self.conv512x512 = self._conv_module(512, 512, conv_kernel, conv_stride, conv_padding, self.relu)
        # Decoder.
        self.upsamp512 = self._upsample_module(512, 512, conv_kernel, conv_stride, conv_padding, self.upsamp, self.relu)
        self.upconv256 = self._conv_module(512, 256, conv_kernel, conv_stride, conv_padding, self.relu)
        self.upsamp256 = self._upsample_module(256, 256, conv_kernel, conv_stride, conv_padding, self.upsamp, self.relu)
        self.upconv128 = self._conv_module(256, 128, conv_kernel, conv_stride, conv_padding, self.relu)
        self.upsamp128 = self._upsample_module(128, 128, conv_kernel, conv_stride, conv_padding, self.upsamp, self.relu)
        self.upconv64 = self._conv_module(128, 64, conv_kernel, conv_stride, conv_padding, self.relu)
        self.upsamp64 = self._upsample_module(64, 64, conv_kernel, conv_stride, conv_padding, self.upsamp, self.relu)
        # Kernel heads: horizontal/vertical kernels for frame 2, then for frame 1.
        self.upconv51_1 = self._kernel_module(64, sep_kernel, conv_kernel, conv_stride, conv_padding, self.upsamp, self.relu)
        self.upconv51_2 = self._kernel_module(64, sep_kernel, conv_kernel, conv_stride, conv_padding, self.upsamp, self.relu)
        self.upconv51_3 = self._kernel_module(64, sep_kernel, conv_kernel, conv_stride, conv_padding, self.upsamp, self.relu)
        self.upconv51_4 = self._kernel_module(64, sep_kernel, conv_kernel, conv_stride, conv_padding, self.upsamp, self.relu)

        # Pad each frame by half the kernel length before the separable convolution
        # (presumably so the filtered output keeps the input's spatial size).
        self.pad = nn.ReplicationPad2d(sep_kernel // 2)

        self.separable_conv = separable_convolution.SeparableConvolutionSlow()

        if init_weights:
            print('Initializing weights...')
            self.apply(self._weight_init)

    @staticmethod
    def from_file(file_path: str) -> nn.Module :
        """Build a ``Net`` and load the weights stored at ``file_path``.

        Tensors are mapped onto the CPU, so checkpoints written on a GPU machine
        also load on CPU-only hosts; move the result with ``.to(device)``.
        """
        model = Net(init_weights=False)
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(file_path, map_location='cpu')
        model.load_state_dict(state_dict)
        return model

    def to_file(self, file_path: str):
        """Save the model's ``state_dict`` (as CPU tensors) to ``file_path``.

        The module is moved to the CPU only for the duration of the save and is
        then returned to the device its parameters were on, so saving mid-training
        does not silently leave the model on the CPU.
        """
        original_device = next(self.parameters()).device
        try:
            torch.save(self.cpu().state_dict(), file_path)
        finally:
            self.to(original_device)

    def forward(self, x):
        """Filter both input frames with predicted kernels and sum the results.

        Args:
            x: tensor whose channels 0-2 and 3-5 hold the two input frames;
               presumably (batch, 6, H, W) with H, W divisible by 32.

        Returns:
            ``sepconv(pad(frame2), k2v, k2h) + sepconv(pad(frame1), k1v, k1h)``.
        """
        i1 = x[:, :3]
        i2 = x[:, 3:6]

        x = self.conv32(x)
        x = self.pool(x)

        x64 = self.conv64(x)
        x128 = self.pool(x64)

        x128 = self.conv128(x128)
        x256 = self.pool(x128)

        x256 = self.conv256(x256)
        x512 = self.pool(x256)

        x512 = self.conv512(x512)
        x = self.pool(x512)

        x = self.conv512x512(x)

        # Skip connections use out-of-place adds: each `x` here is a ReLU output
        # that autograd saved for the backward pass, so `x += skip` would make
        # backward() fail with "modified by an inplace operation".
        x = self.upsamp512(x)
        x = x + x512
        x = self.upconv256(x)

        x = self.upsamp256(x)
        x = x + x256
        x = self.upconv128(x)

        x = self.upsamp128(x)
        x = x + x128
        x = self.upconv64(x)

        x = self.upsamp64(x)
        x = x + x64

        k2h = self.upconv51_1(x)

        k2v = self.upconv51_2(x)

        k1h = self.upconv51_3(x)

        k1v = self.upconv51_4(x)

        padded_i2 = self.pad(i2)
        padded_i1 = self.pad(i1)

        return self.separable_conv(padded_i2, k2v, k2h) + self.separable_conv(padded_i1, k1v, k1h)

    @staticmethod
    def _check_gradients(func, device='cuda'):
        """Run ``torch.autograd.gradcheck`` on a separable-convolution callable.

        ``func`` is called as ``func(image, vertical_kernel, horizontal_kernel)``
        (the argument order ``forward`` uses) with an image of shape
        (2, 3, K, K) and kernels of shape (2, K, 1, 1), K = OUTPUT_1D_KERNEL_SIZE.
        Only the kernels require gradients.

        Args:
            func: the callable under test.
            device: where to allocate the random inputs (defaults to 'cuda').
        """
        print('Starting gradient check...')
        sep_kernel = OUTPUT_1D_KERNEL_SIZE
        # NOTE(review): gradcheck is designed for float64 inputs; with float32 and
        # these loose tolerances the result may be flaky.
        inputs = (
            torch.randn(2, 3, sep_kernel, sep_kernel, device=device, requires_grad=False),
            torch.randn(2, sep_kernel, 1, 1, device=device, requires_grad=True),
            torch.randn(2, sep_kernel, 1, 1, device=device, requires_grad=True),
        )
        passed = gradcheck(func, inputs, eps=1e-3, atol=1e-3, rtol=1e-3)
        print('Gradient check result:', passed)

    @staticmethod
    def _conv_module(in_channels, out_channels, kernel, stride, padding, relu):
        """Three conv+ReLU layers; only the last changes the channel count."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, kernel, stride, padding), relu,
            torch.nn.Conv2d(in_channels, in_channels, kernel, stride, padding), relu,
            torch.nn.Conv2d(in_channels, out_channels, kernel, stride, padding), relu,
        )

    @staticmethod
    def _kernel_module(in_channels, out_channels, kernel, stride, padding, upsample, relu):
        """Kernel head: three conv+ReLU layers, upsample, then a final linear conv."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, kernel, stride, padding), relu,
            torch.nn.Conv2d(in_channels, in_channels, kernel, stride, padding), relu,
            torch.nn.Conv2d(in_channels, out_channels, kernel, stride, padding), relu,
            upsample,
            torch.nn.Conv2d(out_channels, out_channels, kernel, stride, padding)
        )

    @staticmethod
    def _upsample_module(in_channels, out_channels, kernel, stride, padding, upsample, relu):
        """Upsample, then a single conv+ReLU."""
        return torch.nn.Sequential(
            upsample, torch.nn.Conv2d(in_channels, out_channels, kernel, stride, padding), relu,
        )

    @staticmethod
    def _weight_init(m):
        """Orthogonal init (ReLU gain) for Conv2d weights; biases are left as-is."""
        if isinstance(m, nn.Conv2d):
            init.orthogonal_(m.weight, init.calculate_gain('relu'))

In [7]:
# Choose the compute device once; later cells move the model and batches to `device`.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
print("===> CUDA available, proceeding with GPU..." if _has_cuda
      else "===> No GPU found, proceeding with CPU...")

===> No GPU found, proceeding with CPU...


In [8]:
bool(torch.cuda.is_available())  # quick check: can this kernel see a CUDA device?

False

In [10]:
# print('===> Loading datasets...')
# train_set = dataset_preparation.get_training_set()
# validation_set = dataset_preparation.get_validation_set()

# training_data_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE,shuffle=True)
# validation_data_loader = DataLoader(dataset=validation_set,batch_size=BATCH_SIZE, shuffle=False)

In [11]:
print('===> Building model...')
# Module.to moves parameters in place and returns the same module object.
model = Net().to(device)
model

===> Building model...
Initializing weights...


Net(
  (pool): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
  (upsamp): Upsample(scale_factor=2, mode=bilinear)
  (relu): ReLU()
  (conv32): Sequential(
    (0): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(6, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
  )
  (conv64): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
  )
  (conv128): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 

In [12]:
loss_function = nn.L1Loss(reduction='mean')  # mean absolute error between output and target

In [15]:
optimizer = optim.Adamax(params=model.parameters(), lr=1e-3)  # Adamax over all network parameters

In [19]:
def train(epoch):
    """Run one training epoch and print loss and weight/gradient diagnostics.

    Uses the notebook globals ``model``, ``optimizer``, ``loss_function``,
    ``training_data_loader`` and ``device``. Each batch is expected to be an
    (input, target) pair indexable as ``batch[0]``/``batch[1]``.

    Args:
        epoch: epoch number, used only for logging.

    Returns:
        Mean loss over the epoch's batches (0.0 when the loader is empty).
    """
    print("===> Training...")
    model.train()
    # Snapshot the weights so we can report how far this epoch moved them.
    before_pass = [p.detach().clone() for p in model.parameters()]
    num_batches = len(training_data_loader)
    epoch_loss = 0.0
    for iteration, batch in enumerate(training_data_loader, 1):
        inputs, target = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()

        print('Forward pass...')
        output = model(inputs)

        loss_ = loss_function(output, target)

        print('Computing gradients...')
        loss_.backward()

        print('Gradients ready.')
        optimizer.step()

        loss_val = loss_.item()
        epoch_loss += loss_val

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, num_batches, loss_val))

    # Sums of per-parameter L2 norms (not a single global norm).
    weight_l2s = 0.0
    weight_diff_l2s = 0.0
    gradient_l2s = 0.0
    with torch.no_grad():
        for p, p_before in zip(model.parameters(), before_pass):
            weight_l2s += p.norm(2).item()
            weight_diff_l2s += (p - p_before).norm(2).item()
            # .grad stays None for parameters that never received a gradient
            # (e.g. when the loader yielded no batches).
            if p.grad is not None:
                gradient_l2s += p.grad.norm(2).item()
    print("===> Epoch {} norms: weights {:.4f}, update {:.4f}, gradients {:.4f}".format(
        epoch, weight_l2s, weight_diff_l2s, gradient_l2s))

    epoch_loss /= max(num_batches, 1)
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss))
    return epoch_loss

In [20]:
def validate(epoch):
    """Evaluate ``model`` on ``validation_data_loader`` and print averaged metrics.

    Uses the notebook globals ``model``, ``loss_function``,
    ``validation_data_loader`` and ``device``. The model is switched to eval mode
    for the pass and its previous train/eval mode is restored afterwards.

    Args:
        epoch: epoch number (currently unused; kept for call-site symmetry).

    Returns:
        ``(valid_loss, valid_ssmi, valid_psnr)`` averaged over batches
        (zeros when the loader is empty).
    """
    print("===> Running validation...")
    # NOTE(review): `loss` (providing SsimLoss) and `psnr` are not imported or
    # defined in the visible notebook; this raises NameError until they are.
    ssmi = loss.SsimLoss()
    valid_loss, valid_ssmi, valid_psnr = 0, 0, 0
    iters = len(validation_data_loader)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in validation_data_loader:
                inputs, target = batch[0].to(device), batch[1].to(device)
                output = model(inputs)
                valid_loss += loss_function(output, target).item()
                # Subtracted: presumably SsimLoss returns negative SSIM - TODO confirm.
                valid_ssmi -= ssmi(output, target).item()
                valid_psnr += psnr(output, target).item()
    finally:
        model.train(was_training)
    iters = max(iters, 1)
    valid_loss /= iters
    valid_ssmi /= iters
    valid_psnr /= iters
    print("===> Validation loss: {:.4f}".format(valid_loss))
    print("===> Validation SSIM: {:.4f}, PSNR: {:.4f}".format(valid_ssmi, valid_psnr))
    return valid_loss, valid_ssmi, valid_psnr

In [None]:
# Directory for model checkpoints. The previous value was the placeholder "/....",
# an absolute root-level path; default to a relative directory and allow an override.
OUTPUT_DIR = os.environ.get("SEPCONV_OUTPUT_DIR", "checkpoints")

In [18]:
def save_checkpoint(epoch):
    """Save ``model``'s weights for ``epoch`` and refresh the "latest" hard link.

    Writes ``OUTPUT_DIR/model_epoch_<epoch>.pth`` and makes
    ``OUTPUT_DIR/model_epoch_latest.pth`` a hard link to it. The model is moved
    to the CPU only while saving and is returned to ``device`` afterwards, even
    if the save fails.

    Args:
        epoch: epoch number embedded in the checkpoint file name.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    model_out_path = os.path.join(OUTPUT_DIR, "model_epoch_{}.pth".format(epoch))
    model_latest_path = os.path.join(OUTPUT_DIR, 'model_epoch_latest.pth')
    try:
        torch.save(model.cpu().state_dict(), model_out_path)
    finally:
        # .to(device) also handles specific GPUs such as cuda:1 (unlike .cuda()).
        model.to(device)
    if os.path.lexists(model_latest_path):
        os.remove(model_latest_path)
    os.link(model_out_path, model_latest_path)
    print("Checkpoint saved to {}".format(model_out_path))

In [None]:
EPOCHS: int = 10  # NOTE(review): re-defines the EPOCHS constant from the config cell; keep them in sync.

In [21]:
# for epoch in range(1, EPOCHS + 1):
#     train(epoch)
#     save_checkpoint(epoch)
#     validate(epoch)