Drive Setup and File Configuration

In [None]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount = True)

In [None]:
# Create Folder for Datasets
# !mkdir /content/drive/MyDrive/Dataset

# # Unzip training data
# !cp /content/drive/MyDrive/DesnowNet/Snow100K-training.zip /content/drive/MyDrive/Dataset
# !zip -FFv /content/drive/MyDrive/Dataset/Snow100K-training.zip --out /content/drive/MyDrive/Dataset/Snow100K-training2.zip # Can't unzip original; -FFv fixes this

# Repeat for testing set
# !cp /content/drive/MyDrive/DesnowNet/Snow100K-testset.zip /content/drive/MyDrive/Dataset
# !zip -FFv /content/drive/MyDrive/Dataset/Snow100K-testset.zip --out /content/drive/MyDrive/Dataset/Snow100K-testset2.zip 

In [None]:
# # Storing the dataset locally makes indexing much faster 
# # than going through Drive for every file
# !mkdir /content/Dataset

# !cp /content/drive/MyDrive/Dataset/Snow100K-training2.zip /content/Dataset
# !unzip -q /content/Dataset/Snow100K-training2.zip -d /content/Dataset
# !rm /content/Dataset/Snow100K-training2.zip

# !cp /content/drive/MyDrive/Dataset/Snow100K-testset2.zip /content/Dataset
# !unzip -q /content/Dataset/Snow100K-testset2.zip -d /content/Dataset
# !rm /content/Dataset/Snow100K-testset2.zip

In [None]:
# Central configuration: dataset locations and checkpoint handling.
dirs = dict(
    # Root folder handed to the training dataset.
    trainset_root='/content/Dataset/Snow100K-training/all',
    # Root folder handed to the test dataset; change the last letter to one of
    # "S" "M" "L" to change test set.
    testset_root='/content/Dataset/Snow100K-testset/media/jdway/GameSSD/overlapping/test/Snow100K-L',
    # Set to False if not loading from a checkpoint.
    checkpoint_exists=True,
    # Folder the per-epoch checkpoints are written into.
    checkpoint_path_write='/content/drive/MyDrive/checkpoints',
    # Checkpoint file that is read when `checkpoint_exists` is True.
    checkpoint_path_read='/content/drive/MyDrive/checkpoints/final_checkpoint.pt',
)

Imports


In [None]:
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image
import sys
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import time
import math
import dill
from tqdm.notebook import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

Load and Preprocess the Data

In [None]:
# Adapted from https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#RandomCrop
def DesnowPadCrop(img, size, i, j, padding_mode='symmetric'):
    """Pad `img` if it is smaller than `size` on either axis, then crop a square patch.

    Args:
        img: Image accepted by torchvision's functional transforms.
        size: Side length of the square crop.
        i: Top (row) coordinate of the crop's upper-left corner.
        j: Left (column) coordinate of the crop's upper-left corner.
        padding_mode: Padding mode forwarded to ``TF.pad``.

    Returns:
        The ``size`` x ``size`` patch returned by ``TF.crop``.
    """
    width, height = TF.get_image_size(img)

    # `padding_mode` must be passed by keyword: the third positional parameter
    # of TF.pad is `fill`, so passing it positionally never selected the mode.
    # A two-element padding is [left/right, top/bottom].
    if width < size:
        # pad the width if needed
        img = TF.pad(img, [size - width, 0], padding_mode=padding_mode)
    if height < size:
        # pad the height if needed
        img = TF.pad(img, [0, size - height], padding_mode=padding_mode)

    return TF.crop(img, i, j, size, size)
    

In [None]:
# Load the Dataset
class DesnowDataset(data.Dataset):
    """Dataset of aligned (ground truth, snow mask, snowy image) triplets.

    The three images of a triplet share one file name and live in the ``gt``,
    ``mask`` and ``synthetic`` sub-folders of ``main_dir``. Each item is a
    random 64x64 crop taken at the same position in all three images and
    returned as float32 tensors in (C, H, W) layout, scaled by 1/255.

    Args:
        main_dir: Folder containing the ``gt``, ``mask`` and ``synthetic`` sub-folders.
        img_names_list: File names of the triplets to expose.
        transform: Optional callable applied to each of the three PIL images
            before cropping.
    """

    def __init__(self, main_dir, img_names_list, transform = None):
        super(DesnowDataset, self).__init__()

        self.main_dir = main_dir
        self.transform = transform
        self.n_imgs = len(img_names_list)

        # os.path.join over string concatenation for OS flexibility
        self.gt_imgs = [os.path.join(main_dir, 'gt', img_filename) for img_filename in img_names_list]
        self.mask_imgs = [os.path.join(main_dir, 'mask', img_filename) for img_filename in img_names_list]
        self.snowy_imgs = [os.path.join(main_dir, 'synthetic', img_filename) for img_filename in img_names_list]

    def __len__(self):
        return self.n_imgs

    @staticmethod
    def _load_rgb(path):
        """Read an image fully into memory as RGB and close the underlying file."""
        with Image.open(path) as img:
            # convert() loads the pixel data and returns an independent image,
            # so it stays valid after the file is closed. Forcing RGB guarantees
            # the channel axis that the permute in _to_tensor relies on.
            return img.convert('RGB')

    @staticmethod
    def _to_tensor(img):
        """Convert an image to a float32 (C, H, W) tensor scaled by 1/255."""
        arr = np.array(img, dtype=np.float32) / 255
        return torch.from_numpy(arr).permute((2, 0, 1))

    def __getitem__(self, idx):
        crop_size = 64

        # PIL Images for torch.transforms
        gt_img = self._load_rgb(self.gt_imgs[idx])
        mask_img = self._load_rgb(self.mask_imgs[idx])
        snowy_img = self._load_rgb(self.snowy_imgs[idx])

        # Apply transformations (if any)
        if self.transform:
            gt_img = self.transform(gt_img)
            mask_img = self.transform(mask_img)
            snowy_img = self.transform(snowy_img)

        # get_image_size returns [width, height] (not [height, width]).
        width, height = TF.get_image_size(gt_img)
        # Clamp the upper bound at 0 so images smaller than the crop (which get
        # padded by DesnowPadCrop) do not make randint raise on an empty range.
        top = random.randint(0, max(0, height - crop_size))
        left = random.randint(0, max(0, width - crop_size))

        # Same crop window for all three images so they stay aligned.
        gt_img = DesnowPadCrop(gt_img, crop_size, top, left, padding_mode='symmetric')
        mask_img = DesnowPadCrop(mask_img, crop_size, top, left, padding_mode='symmetric')
        snowy_img = DesnowPadCrop(snowy_img, crop_size, top, left, padding_mode='symmetric')

        # Scale manually instead of relying on ToTensor()/PILToTensor().
        return (self._to_tensor(gt_img), self._to_tensor(mask_img), self._to_tensor(snowy_img))

Inception v4

In [None]:
# Adapted from https://github.com/Cadene/pretrained-models.pytorch
# Note that the original model is based on classification; the code has been modified to fit our needs
# We also change the convolution and pooling operations to be "same" (stride 1 with padding), so the spatial size is preserved

class BasicConv2d(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and an in-place ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
        super().__init__()
        # No conv bias: the BatchNorm that follows has its own learnable shift.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        # eps and momentum are the default PyTorch values (the reference code
        # noted eps=0.001 as the value found in TensorFlow).
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class Mixed_3a(nn.Module):
    """A 3x3 max-pool and a 3x3 conv run in parallel and are concatenated (16 -> 40 channels).

    Both paths use stride 1 and padding 1, so the spatial size is unchanged.
    """

    def __init__(self):
        super().__init__()
        self.maxpool = nn.MaxPool2d(3, stride=1, padding=1)
        self.conv = BasicConv2d(16, 24, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        pooled = self.maxpool(x)
        convolved = self.conv(x)
        # Channel-wise concatenation: pooled input first, conv features second.
        return torch.cat((pooled, convolved), 1)


class Mixed_4a(nn.Module):
    """Two parallel conv stacks whose 24-channel outputs are concatenated (40 -> 48 channels)."""

    def __init__(self):
        super().__init__()

        # Short path: 1x1 bottleneck, then a 3x3 convolution.
        self.branch0 = nn.Sequential(
            BasicConv2d(40, 16, kernel_size=1, stride=1),
            BasicConv2d(16, 24, kernel_size=3, stride=1, padding=1),
        )

        # Long path: 1x1 bottleneck, 1x7 and 7x1 convolutions, then a 3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv2d(40, 16, kernel_size=1, stride=1),
            BasicConv2d(16, 16, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(16, 16, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(16, 24, kernel_size=(3, 3), stride=1, padding=1),
        )

    def forward(self, x):
        branches = (self.branch0, self.branch1)
        return torch.cat([branch(x) for branch in branches], 1)


class Mixed_5a(nn.Module):
    """A 3x3 conv and a 3x3 max-pool run in parallel and are concatenated (48 -> 96 channels).

    Both paths use stride 1 and padding 1, so the spatial size is unchanged.
    """

    def __init__(self):
        super().__init__()
        self.conv = BasicConv2d(48, 48, kernel_size=3, stride=1, padding=1)
        self.maxpool = nn.MaxPool2d(3, stride=1, padding=1)

    def forward(self, x):
        convolved = self.conv(x)
        pooled = self.maxpool(x)
        # Channel-wise concatenation: conv features first, pooled input second.
        return torch.cat((convolved, pooled), 1)


class Inception_A(nn.Module):
    """Four parallel branches of 24 channels each, concatenated (96 -> 96 channels)."""

    def __init__(self):
        super().__init__()
        # 1x1 convolution only.
        self.branch0 = BasicConv2d(96, 24, kernel_size=1, stride=1)

        # 1x1 bottleneck followed by one 3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv2d(96, 16, kernel_size=1, stride=1),
            BasicConv2d(16, 24, kernel_size=3, stride=1, padding=1),
        )

        # 1x1 bottleneck followed by two stacked 3x3 convolutions.
        self.branch2 = nn.Sequential(
            BasicConv2d(96, 16, kernel_size=1, stride=1),
            BasicConv2d(16, 24, kernel_size=3, stride=1, padding=1),
            BasicConv2d(24, 24, kernel_size=3, stride=1, padding=1),
        )

        # 3x3 average pooling (padding excluded from the average), then 1x1 convolution.
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(96, 24, kernel_size=1, stride=1),
        )

    def forward(self, x):
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)


class Reduction_A(nn.Module):
    """Three parallel branches concatenated: 96 + 64 + 96 pooled channels (96 -> 256).

    Every layer here uses stride 1 with padding, so only the channel count
    changes; the spatial size is kept.
    """

    def __init__(self):
        super().__init__()
        # Single 3x3 convolution.
        self.branch0 = BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)

        # 1x1 bottleneck followed by two 3x3 convolutions.
        self.branch1 = nn.Sequential(
            BasicConv2d(96, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 56, kernel_size=3, stride=1, padding=1),
            BasicConv2d(56, 64, kernel_size=3, stride=1, padding=1),
        )

        # 3x3 max pooling of the input (keeps its 96 channels).
        self.branch2 = nn.MaxPool2d(3, stride=1, padding=1)

    def forward(self, x):
        branches = (self.branch0, self.branch1, self.branch2)
        return torch.cat([branch(x) for branch in branches], 1)


class Inception_B(nn.Module):
    """Four parallel branches (96 + 64 + 64 + 32 channels), concatenated (256 -> 256)."""

    def __init__(self):
        super().__init__()
        # 1x1 convolution only.
        self.branch0 = BasicConv2d(256, 96, kernel_size=1, stride=1)

        # 1x1 bottleneck, then one 1x7 / 7x1 pair.
        self.branch1 = nn.Sequential(
            BasicConv2d(256, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 56, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(56, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        )

        # 1x1 bottleneck, then two 7x1 / 1x7 pairs.
        self.branch2 = nn.Sequential(
            BasicConv2d(256, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 48, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(48, 56, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(56, 56, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(56, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
        )

        # 3x3 average pooling (padding excluded from the average), then 1x1 convolution.
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(256, 32, kernel_size=1, stride=1),
        )

    def forward(self, x):
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)


class Reduction_B(nn.Module):
    """Three parallel branches concatenated: 48 + 80 + 256 pooled channels (256 -> 384).

    Every layer here uses stride 1 with padding, so only the channel count
    changes; the spatial size is kept.
    """

    def __init__(self):
        super().__init__()

        # 1x1 bottleneck followed by a 3x3 convolution.
        self.branch0 = nn.Sequential(
            BasicConv2d(256, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 48, kernel_size=3, stride=1, padding=1),
        )

        # 1x1 bottleneck, a 1x7 / 7x1 pair, then a 3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv2d(256, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(64, 80, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(80, 80, kernel_size=3, stride=1, padding=1),
        )

        # 3x3 max pooling of the input (keeps its 256 channels).
        self.branch2 = nn.MaxPool2d(3, stride=1, padding=1)

    def forward(self, x):
        branches = (self.branch0, self.branch1, self.branch2)
        return torch.cat([branch(x) for branch in branches], 1)


class Inception_C(nn.Module):
    """Four branches (64 + 128 + 128 + 64 channels), concatenated (384 -> 384).

    Branches 1 and 2 each end in a 1x3 / 3x1 pair that is applied to the same
    intermediate tensor and concatenated.
    """

    def __init__(self):
        super().__init__()

        self.branch0 = BasicConv2d(384, 64, kernel_size=1, stride=1)

        # Branch 1: 1x1 bottleneck feeding a 1x3 and a 3x1 convolution.
        self.branch1_0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
        self.branch1_1a = BasicConv2d(96, 64, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch1_1b = BasicConv2d(96, 64, kernel_size=(3, 1), stride=1, padding=(1, 0))

        # Branch 2: 1x1 -> 3x1 -> 1x3, feeding a 1x3 and a 3x1 convolution.
        self.branch2_0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
        self.branch2_1 = BasicConv2d(96, 112, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_2 = BasicConv2d(112, 128, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3a = BasicConv2d(128, 64, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3b = BasicConv2d(128, 64, kernel_size=(3, 1), stride=1, padding=(1, 0))

        # Branch 3: 3x3 average pooling (padding excluded), then 1x1 convolution.
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(384, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        out0 = self.branch0(x)

        shared1 = self.branch1_0(x)
        out1 = torch.cat((self.branch1_1a(shared1), self.branch1_1b(shared1)), 1)

        shared2 = self.branch2_2(self.branch2_1(self.branch2_0(x)))
        out2 = torch.cat((self.branch2_3a(shared2), self.branch2_3b(shared2)), 1)

        out3 = self.branch3(x)

        return torch.cat((out0, out1, out2, out3), 1)


class InceptionV4(nn.Module):
    """Feature extractor built from the Inception-v4 style blocks defined alongside it.

    All layers use stride 1 with padding, so the returned feature map has the
    input's spatial size; the last block emits 384 channels. The
    classification head of the reference implementation (average pooling plus
    a linear layer) is not part of this model.

    Args:
        num_channels: Number of channels of the input tensor.
    """

    def __init__(self, num_channels = 3):
        super().__init__()
        # Special attributes kept from the reference implementation; nothing in
        # this class reads them.
        self.input_space = None
        self.input_size = (299, 299, 3)
        self.mean = None
        self.std = None

        # Layers are created strictly in network order so parameter order and
        # random initialisation order stay the same.
        layers = [
            # "Same" convolutions (kernel size 3, stride 1, padding 1)
            BasicConv2d(num_channels, 8, kernel_size=3, stride=1, padding=1),
            BasicConv2d(8, 8, kernel_size=3, stride=1, padding=1),
            BasicConv2d(8, 16, kernel_size=3, stride=1, padding=1),
            Mixed_3a(),
            Mixed_4a(),
            Mixed_5a(),
        ]
        layers += [Inception_A() for _ in range(4)]
        layers.append(Reduction_A())  # Mixed_6a
        layers += [Inception_B() for _ in range(7)]
        layers.append(Reduction_B())  # Mixed_7a
        layers += [Inception_C() for _ in range(3)]

        self.features = nn.Sequential(*layers)

    def forward(self, input):
        return self.features(input)

Dilation Pyramid

In [None]:
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/deeplabv3.py
# Pooling and Projection removed; we basically just need Conv2d, BatchNorm2d, and ReLU

class ASPPConv(nn.Sequential):
    """Bias-free 3x3 dilated convolution -> BatchNorm2d -> ReLU.

    The padding equals the dilation rate, so the spatial size is preserved.
    """

    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
        super().__init__(
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPP(nn.Module):
    """Dilation pyramid: one 1x1 branch plus one dilated 3x3 branch per rate.

    All branches see the same input and their outputs are concatenated along
    the channel axis, giving ``out_channels * (1 + len(atrous_rates))`` channels.

    Args:
        in_channels: Channels of the input feature map.
        atrous_rates: Iterable of dilation rates, one ``ASPPConv`` per entry.
        out_channels: Channels produced by each branch.
    """

    def __init__(self, in_channels: int, atrous_rates, out_channels: int) -> None:
        super().__init__()
        # The 1x1 branch is created first so it stays at index 0 of `convs`.
        pointwise = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        dilated = [ASPPConv(in_channels, out_channels, rate) for rate in tuple(atrous_rates)]

        self.convs = nn.ModuleList([pointwise] + dilated)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([conv(x) for conv in self.convs], dim=1)

Pyramid Maxout

In [None]:
# Hard coding for beta = 4, specified in the paper
class PyramidMaxout_b4(nn.Module):
    """Element-wise maximum over four conv branches (kernel sizes 1, 3, 5, 7), then PReLU.

    The 5x5 and 7x7 branches are each built as a 1xk convolution followed by a
    kx1 convolution. All padded convolutions use replicate padding, so the
    spatial size is preserved.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by every branch (and by the module).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def separable(k):
            # k x k receptive field built as a 1xk convolution then a kx1 convolution.
            half = k // 2
            return nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=(1, k), stride=1,
                          padding=(0, half), padding_mode='replicate'),
                nn.Conv2d(in_channels, out_channels, kernel_size=(k, 1), stride=1,
                          padding=(half, 0), padding_mode='replicate'),
            )

        # "weight decay should not be used when learning alpha for good performance"
        self.prelu = nn.PReLU()
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.conv_3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                                padding=1, padding_mode='replicate')
        self.conv_5 = separable(5)
        self.conv_7 = separable(7)

    def forward(self, x):
        small = torch.fmax(self.conv_1(x), self.conv_3(x))
        large = torch.fmax(self.conv_5(x), self.conv_7(x))
        return self.prelu(torch.fmax(small, large))

Translucency Recovery

In [None]:
# Generate SE and AE
class Rt(nn.Module):
    """Translucency recovery: predicts z hat (SE) and a (AE) and computes y'.

    ``forward(ft, x)`` returns ``(z_broadcasted, y_prime, fc)`` where

    * ``z_broadcasted`` is the single-channel SE output repeated three times
      along the channel axis,
    * ``y_prime = (x - a * z) / (1 + 1e-5 - z)``, replaced by ``x`` wherever
      ``z >= 1``,
    * ``fc`` is the channel-wise concatenation of ``y_prime``, ``z_broadcasted``
      and ``a``.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.SE = PyramidMaxout_b4(in_channels, 1) # z hat
        self.AE = PyramidMaxout_b4(in_channels, 3) # a

    def forward(self, ft, x):
        a = self.AE(ft)
        z = self.SE(ft)
        z_broadcasted = torch.cat([z] * 3, dim = 1)

        # Case 1: the 1e-5 term keeps the denominator away from zero at z == 1.
        y_prime = (x - a * z_broadcasted) / (1 + 1e-5 - z_broadcasted)
        # Case 2: where z >= 1, take the input pixel itself.
        saturated = z_broadcasted >= 1
        y_prime[saturated] = x[saturated]

        fc = torch.cat([y_prime, z_broadcasted, a], dim = 1)
        return z_broadcasted, y_prime, fc

Residual Generation

In [None]:
# Hard coding for beta = 4, specified in the paper
class PyramidSum_b4(nn.Module):
    """Element-wise sum of four conv branches (kernel sizes 1, 3, 5, 7).

    The 5x5 and 7x7 branches are each built as a 1xk convolution followed by a
    kx1 convolution. All padded convolutions use replicate padding, so the
    spatial size is preserved.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by every branch (and by the module).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def separable(k):
            # k x k receptive field built as a 1xk convolution then a kx1 convolution.
            half = k // 2
            return nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=(1, k), stride=1,
                          padding=(0, half), padding_mode='replicate'),
                nn.Conv2d(in_channels, out_channels, kernel_size=(k, 1), stride=1,
                          padding=(half, 0), padding_mode='replicate'),
            )

        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.conv_3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                                padding=1, padding_mode='replicate')
        self.conv_5 = separable(5)
        self.conv_7 = separable(7)

    def forward(self, x):
        out_1 = self.conv_1(x)
        out_3 = self.conv_3(x)
        out_5 = self.conv_5(x)
        out_7 = self.conv_7(x)

        return out_1 + out_3 + out_5 + out_7

Loss Function

In [None]:
# P_n denotes the max-pooling operation with kernel size nxn and stride nxn for non-overlapped pooling

class PyramidLossPool(nn.Sequential):
    """Non-overlapping max pooling with an n x n kernel and stride n."""

    def __init__(self, n: int) -> None:
        super().__init__(nn.MaxPool2d(kernel_size=n, stride=n))


class PyramidLoss(nn.Module):
    """Sum of MSE losses computed on max-pooled versions of both inputs.

    One pooling level is created for every odd size ``1, 3, ..., 2 * tau - 1``.

    Args:
        tau: Number of pyramid levels.
    """

    def __init__(self, tau: int) -> None:
        super().__init__()
        self.pools = nn.ModuleList([PyramidLossPool(n) for n in range(1, tau * 2, 2)])
        self.mse = nn.MSELoss()

    def forward(self, x, y):
        # sum() starts from the integer 0, exactly like the running total it replaces.
        return sum(self.mse(pool(x), pool(y)) for pool in self.pools)

Xavier Initialization

In [None]:
# Adapted (directly) from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
def weight_init(m):
    '''
    Re-initialise the parameters of a single module according to its type.

    Usage:
        model = Model()
        model.apply(weight_init)

    Rules:
        * 1-D (transposed) convolutions: normal weights, normal bias.
        * 2-D / 3-D (transposed) convolutions: Xavier-normal weights, normal bias.
        * Batch norm (1d/2d/3d): weights ~ N(1, 0.02), bias set to 0.
        * Linear: Xavier-normal weights, normal bias.
        * LSTM / LSTMCell / GRU / GRUCell: orthogonal for parameters with two
          or more dimensions, normal otherwise.
    Modules of any other type are left untouched.
    '''
    normal_weight_convs = (nn.Conv1d, nn.ConvTranspose1d)
    xavier_weight_convs = (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    batch_norms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    recurrent = (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)

    if isinstance(m, normal_weight_convs + xavier_weight_convs):
        if isinstance(m, normal_weight_convs):
            init.normal_(m.weight.data)
        else:
            init.xavier_normal_(m.weight.data)
        # Convolutions may be built with bias=False.
        if m.bias is not None:
            init.normal_(m.bias.data)
        return

    if isinstance(m, batch_norms):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
        return

    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
        return

    if isinstance(m, recurrent):
        for param in m.parameters():
            # Weight matrices have >= 2 dims; bias vectors have 1.
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

Full Model

In [None]:
class DesnowNet(nn.Module):
    """Full model: translucency recovery followed by residual generation.

    ``forward(x)`` returns ``(y_hat, y_prime, z_hat)``: the final estimate
    ``y_prime + r``, the translucency-recovered image, and the broadcast snow
    mask estimate produced by ``Rt``.
    """

    def __init__(self):
        super().__init__()

        self.Dt = InceptionV4(3)              # descriptor on the 3-channel input
        self.DP = ASPP(384, [2,4,8,16], 192)  # 5 branches x 192 = 960 channels
        self.Dr = InceptionV4(9)              # descriptor on fc (y', z, a: 3 + 3 + 3 channels)
        self.Rt = Rt(960)
        self.Rr = PyramidSum_b4(384, 3)

        # We have to initialize the submodules
        for submodule in (self.Dt, self.DP, self.Dr, self.Rt, self.Rr):
            submodule.apply(weight_init)

    def forward(self, x):
        ft = self.DP(self.Dt(x))
        z_hat, y_prime, fc = self.Rt(ft, x)
        r = self.Rr(self.Dr(fc))
        return y_prime + r, y_prime, z_hat

Initialize Datasets and Model

In [None]:
# Adapted from https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/train.py

# img_paths = dill.load(open('/content/drive/MyDrive/Dataset/test_list_L', 'rb'))
# Training set: every file name found under <trainset_root>/gt identifies one triplet.
img_paths = os.listdir(os.path.join(dirs['trainset_root'], 'gt'))
trainset = DesnowDataset(dirs['trainset_root'], img_paths)
trainset_dataloader = data.DataLoader(
    trainset,
    batch_size=5,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

# Test set: built the same way from <testset_root>/gt.
# test_paths = dill.load(open('/content/drive/MyDrive/Dataset/test_list_L', 'rb'))
test_paths = os.listdir(os.path.join(dirs['testset_root'], 'gt'))
testset = DesnowDataset(dirs['testset_root'], test_paths)
testset_dataloader = data.DataLoader(
    testset,
    batch_size=5,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

In [None]:
# Build the network and move it to the selected device (Module.to returns the module itself).
model = DesnowNet().to(device)

# --- Build optimizer --- #
# weight_decay is λw; it accounts for the final term of Eq (8).
optimizer = optim.Adam(params=model.parameters(), lr=3e-5, weight_decay=5e-4)

# --- Define the loss function (criterion) --- #
# Tau is unspecified; consider using 4 similar to beta in Pyramid Sum/Max
criterion = PyramidLoss(tau=4).to(device)
# Plain MSE for the PSNR calculation (consider also SSIM)
MSE = nn.MSELoss().to(device)

Load Checkpoint

In [None]:
# Remember to set dirs['checkpoint_exists'] to True after the first epoch
if dirs['checkpoint_exists']:
    # map_location remaps the stored tensors onto the currently selected device.
    checkpoint = torch.load(dirs['checkpoint_path_read'], map_location=device)
    model.load_state_dict(checkpoint['model state'])
    optimizer.load_state_dict(checkpoint['optimizer state'])
    # Number of the last completed epoch; training resumes at start_epoch + 1.
    start_epoch = checkpoint['epoch']
    # loss = checkpoint['Train loss']
    print("Loaded epoch", start_epoch)
else:
    # No epoch has been completed yet. The training loop begins at
    # start_epoch + 1, so 0 makes the first epoch (and its checkpoint) number 1.
    start_epoch = 0

Train

In [None]:
# Adapted from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

for epoch in range(start_epoch + 1, 999): # Will have to stop manually
    running_loss = 0
    avg_loss = 0

    model.train()
    training_loop = tqdm(trainset_dataloader, leave=False, position=0)
    for iter_idx, img_data in enumerate(training_loop):

        # Move the batch to the training device
        gt_img, mask_img, snowy_img = img_data
        gt_img, mask_img, snowy_img = gt_img.to(device), mask_img.to(device), snowy_img.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        y_hat, y_prime, z_hat = model(snowy_img)

        # Compute the loss and its gradients: both image estimates are
        # compared with the ground truth, the mask estimate with the true mask
        # (weighted by 3).
        loss = criterion(y_prime, gt_img) + criterion(y_hat, gt_img) + 3 * criterion(z_hat, mask_img)
        running_loss += loss.item()
        avg_loss = running_loss / (iter_idx + 1)  # running mean over this epoch's batches
        loss.backward()

        # Adjust learning weights
        optimizer.step()

    print("Training Loss", avg_loss)

    # Make sure the target folder exists so a whole epoch of training is not
    # lost to a failing save.
    os.makedirs(dirs['checkpoint_path_write'], exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model state': model.state_dict(),
        'optimizer state': optimizer.state_dict(),
        # Mean training loss of this epoch (previously referenced an undefined name).
        'Train loss': avg_loss},
        os.path.join(dirs['checkpoint_path_write'], f'model_epoch_{epoch}.pt'))

    print("Saved state at epoch", epoch)

Test

In [None]:
model.eval()
running_psnr = 0
avg_psnr = 0

test_loop = tqdm(testset_dataloader, leave=False, position=0)
# No gradients are needed for evaluation; disabling them avoids building the
# autograd graph for every batch.
with torch.no_grad():
    for iter_idx, img_data in enumerate(test_loop):

        # Forward propagation
        gt_img, mask_img, snowy_img = img_data
        gt_img, mask_img, snowy_img = gt_img.to(device), mask_img.to(device), snowy_img.to(device)
        y_hat, y_prime, z_hat = model(snowy_img)

        # Clip output to the valid image range
        y_hat = torch.clamp(y_hat, min=0, max=1)

        # Compute MSE and PSNR (peak value is 1 because images are scaled to [0, 1]).
        loss = MSE(y_hat, gt_img)
        # Floor the MSE so a perfect batch does not divide by zero; this caps
        # the per-batch PSNR at 100 dB.
        batch_mse = max(loss.item(), 1e-10)
        running_psnr += (10 * np.log10(1 / batch_mse))
        avg_psnr = running_psnr / (iter_idx + 1)  # mean of the per-batch PSNR values

print("PSNR: ", avg_psnr)