In [1]:
# cd /kaggle/working

In [2]:
# !git clone https://github.com/naver/r2d2

In [3]:
# cd r2d2


In [1]:
import os, pdb
import torch
import torch.optim as optim

from tools import common, trainer
from tools.dataloader import *
from nets.patchnet import *
from nets.losses import *
from PIL import Image, ImageOps
from tqdm.notebook import tqdm 


In [2]:
# try:
#     DataPreparation
# except NameError:
#     os.system("cp -r /kaggle/input/naturalimages/ /tmp/dataset")
#     os.system("ln -s /tmp/dataset /kaggle/working/r2d2/data")
#     DataPreparation = True
# else:
#     pass

In [3]:
# !cp -r /kaggle/input/naturalimages/ /tmp/dataset
# !ln -s /tmp/dataset /kaggle/working/r2d2/data

In [4]:
# ls /tmp/dataset

In [5]:
# ls

In [7]:
from  datasets import *

In [8]:
class StillTransform (object):
    """ Base class for photometric transforms that never touch image geometry.

    Subclasses implement `_transform(img) -> img` and must return an image of the same
    size. `__call__` extracts the image with `F.grab_img`, transforms it, and writes it
    back with an identity perspective via `F.update_img_and_labels`.

    NOTE(review): `F` is not imported in any visible cell; it must provide `grab_img`
    and `update_img_and_labels` (presumably r2d2's `tools.transforms_tools`) -- confirm.
    """
    def _transform(self, img):
        raise NotImplementedError()

    def __call__(self, inp):
        source = F.grab_img(inp)

        # A TypeError from _transform means "cannot transform this input": the grabbed
        # image is passed through unchanged instead of failing.
        try:
            result = self._transform(source)
        except TypeError:
            result = source

        # Geometry is untouched, so report the identity perspective
        # (8 parameters; presumably a 3x3 homography with the last entry fixed to 1).
        identity = (1, 0, 0, 0, 1, 0, 0, 0)
        return F.update_img_and_labels(inp, result, persp=identity)

    
class PixelSpeckleNoise (StillTransform):
    """ Takes an image and applies multiplicative (speckle) noise.

    The image is first rescaled to [0,1]. Each pixel is then multiplied by (1 + u), with u
    drawn uniformly from [-sqrt(12*var)/2, +sqrt(12*var)/2], i.e. zero mean and variance
    `var`. The result is clipped to [0,1] and linearly stretched to [0,255].

    Args:
        var:  variance of the multiplicative noise term, in [0, 1).
        seed: if not None, every call draws the same noise (for a given image shape);
              if None, fresh noise is drawn on each call. The global NumPy RNG is never
              reseeded.
    """
    def __init__(self, var=.05, seed=None):
        StillTransform.__init__(self)
        assert 0 <= var < 1
        self.var = var
        self.seed = seed

    def __repr__(self):
        return "PixelSpeckleNoise(%g)" % self.var

    def normalize(self, img, minimum=0, maximum=1):
        """ Linearly rescale `img` so its min maps to `minimum` and its max to `maximum`.

        Returns a float64 array. A constant image (max == min) has no range to stretch,
        so it is mapped entirely to `minimum` instead of producing NaNs from 0/0.
        """
        arr = np.asarray(img, dtype=np.float64)
        img_min = arr.min()
        span = arr.max() - img_min
        if span == 0:
            return np.full(arr.shape, float(minimum))
        return (arr - img_min) / span * (maximum - minimum) + minimum

    def _transform(self, img):
        # Convert up front: PIL images have no `.shape`. Inputs that are not array-like
        # raise TypeError here, which StillTransform.__call__ treats as "leave unchanged".
        arr = np.asarray(img, dtype=np.float64)
        normalized_img = self.normalize(arr)
        upper_band = (12 * self.var) ** .5
        # Private generator: yields the same draws as np.random.seed(seed) followed by
        # np.random.uniform, but without reseeding the global NumPy RNG.
        rng = np.random.RandomState(self.seed)
        noise = rng.uniform(-upper_band / 2, upper_band / 2, size=arr.shape)
        noisy_img = np.clip(normalized_img * (1 + noise), 0, 1)
        ret_val = self.normalize(noisy_img, maximum=255)
        return Image.fromarray(np.uint8(ret_val))
    
class PixelNoise (StillTransform):
    """ Takes an image and adds uniform (white) noise to every pixel.

    The noise is drawn from the global NumPy RNG over a window of width `ampl` centred
    on +0.5. Presumably the +0.5 offset compensates for np.uint8 truncating toward zero,
    so values round to nearest on average. The result is clipped to [0,255].
    """
    def __init__(self, ampl=20):
        StillTransform.__init__(self)
        assert 0 <= ampl < 255
        self.ampl = ampl

    def __repr__(self):
        return "PixelNoise(%g)" % self.ampl

    def _transform(self, img):
        # np.float32 raises TypeError on non-array-like inputs; the base class then
        # leaves the image unchanged.
        pixels = np.float32(img)
        half_width = self.ampl / 2
        jitter = np.random.uniform(0.5 - half_width, 0.5 + half_width, size=pixels.shape)
        pixels += jitter
        return Image.fromarray(np.uint8(np.clip(pixels, 0, 255)))


In [9]:
# Network architecture as a Python constructor expression (a string, eval'd elsewhere).
default_net = "Quad_L2Net_ConfCFS()"

# ---- Dataset specs ----
# Each spec below is a Python expression *string*. The data-loader cell joins the
# selected specs with ',', splices them into the `data` slot of the loader template,
# strips newlines and evaluates the result with `eval`. Every name used inside a spec
# (ImgFolder, web_images, aachen_db_images, ...) must therefore be defined in the
# notebook namespace at that moment.
# NOTE(review): the recorded run failed there with "NameError: name 'web_images' is not
# defined" -- confirm that `from datasets import *` imports the repo's local package.
# The quoted transform lists (e.g. 'RandomTilting(0.5), PixelSpeckleNoise(.5)') are
# plain strings handed to the dataset classes. NOTE(review): confirm that they are
# resolved in a namespace where the PixelSpeckleNoise defined in this notebook is visible.

# Debug set built from ImgFolder('imgs').
toy_db_debug = """SyntheticPairDataset(
    ImgFolder('imgs'), 
            'RandomScale(256,1024,can_upscale=True)', 
            'RandomTilting(0.5), PixelSpeckleNoise(.5)')"""

# Code 'W' in `data_sources`.
db_web_images = """SyntheticPairDataset(
    web_images, 
        'RandomScale(256,1024,can_upscale=True)',
        'RandomTilting(0.5), PixelSpeckleNoise(.5)')"""

# Code 'A' in `data_sources`.
db_aachen_images = """SyntheticPairDataset(
    aachen_db_images, 
        'RandomScale(256,1024,can_upscale=True)', 
        'RandomTilting(0.5), PixelSpeckleNoise(.5)')"""

# Code 'S': already-paired images, so a single transform list is applied (TransformedPairs).
db_aachen_style_transfer = """TransformedPairs(
    aachen_style_transfer_pairs,
            'RandomScale(256,1024,can_upscale=True), RandomTilting(0.5), PixelSpeckleNoise(.5)')"""

# Code 'F': used as-is, with no extra wrapping or transforms.
db_aachen_flow = "aachen_flow_pairs"


# Code 'X': SAR images. RandomScale(256,256,can_upscale=False) presumably pins the
# scale to 256 without upscaling. NOTE(review): `sar_db_images` must be provided by the
# imported `datasets` package -- confirm it exists there.
db_sar_images = """SyntheticPairDataset(
    sar_db_images, 
        'RandomScale(256,256,can_upscale=False)', 
        'RandomTilting(0.5), PixelSpeckleNoise(.5)')"""



In [10]:

# Templates for the training pipeline. Back-quoted tokens are textual placeholders
# filled with str.replace before `eval`:
#   `data`    -- the comma-joined dataset specs (data-loader cell);
#   `sampler` -- the sampler spec, and `N` -- the patch size (create-losses cell).

# Pair loader: scale, photometric distortion and crop applied to every pair.
default_dataloader = """PairLoader(CatPairDataset(`data`),
    scale   = 'RandomScale(256,1024,can_upscale=True)',
    distort = 'ColorJitter(0.2,0.2,0.2,0.1)',
    crop    = 'RandomCrop(192)')"""

# Sampler spec, substituted into the `sampler` slot of the loss below.
default_sampler = """NghSampler2(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16,
                            subd_neg=-8,maxpool_pos=True)"""

# Loss spec: presumably (weight, loss) pairs combined by MultiLoss -- confirm in nets/losses.
default_loss = """MultiLoss(
        1, ReliabilityLoss(`sampler`, base=0.5, nq=20),
        1, CosimLoss(N=`N`),
        1, PeakyLoss(N=`N`))"""


In [14]:
# One-letter code -> dataset spec string. `train_data` (e.g. "WASFX") selects the
# sources to concatenate by listing their codes.
data_sources = {
    'D': toy_db_debug,
    'W': db_web_images,
    'A': db_aachen_images,
    'F': db_aachen_flow,
    'S': db_aachen_style_transfer,
    'X': db_sar_images,
}


In [12]:
class MyTrainer(trainer.Trainer):
    """ Trainer for image-pair batches.

    Only `forward_backward` is overridden. It runs the network on both images of the
    pair and passes the remaining batch entries plus the network outputs to the loss.
    It backpropagates only when autograd is enabled and returns `(loss, details)`.
    """
    def forward_backward(self, inputs):
        # Remove the two images from the batch dict (this mutates `inputs`). Whatever
        # remains is forwarded to the loss.
        img_a = inputs.pop('img1')
        img_b = inputs.pop('img2')
        net_out = self.net(imgs=[img_a, img_b])
        # Network outputs take precedence over batch entries on key clashes.
        loss_kwargs = dict(inputs, **net_out)
        loss, details = self.loss_func(**loss_kwargs)
        # No backprop inside no-grad contexts (e.g. evaluation under torch.no_grad()).
        if torch.is_grad_enabled():
            loss.backward()
        return loss, details



In [15]:

def load_network(model_fn):
    """ Rebuild a network from a checkpoint file and return it in eval mode.

    The checkpoint must be a dict with:
      'net'        -- a Python expression string that constructs the network
                      (evaluated with `eval`, e.g. "Quad_L2Net_ConfCFS()");
      'state_dict' -- the weights. Keys may carry a leading 'module.' prefix, as
                      produced by saving a DataParallel-wrapped model.

    Args:
        model_fn: path to the checkpoint file.
    Returns:
        The network with its weights loaded, switched to eval mode.
    """
    # SECURITY: torch.load unpickles the file and `eval` executes the 'net' string --
    # only load checkpoints from trusted sources.
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts; the weights
    # are copied into the freshly built (CPU) network either way.
    checkpoint = torch.load(model_fn, map_location='cpu')
    print("\n>> Creating net = " + checkpoint['net'])
    net = eval(checkpoint['net'])
    nb_of_weights = common.model_size(net)
    print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")

    # Strip only a *leading* 'module.' prefix. A global str.replace would also mangle
    # keys such as 'submodule.weight' into 'subweight'.
    prefix = 'module.'
    weights = checkpoint['state_dict']
    net.load_state_dict({(k[len(prefix):] if k.startswith(prefix) else k): v
                         for k, v in weights.items()})
    return net.eval()


In [16]:
# Output directory for trained checkpoints. This is idempotent, unlike the shell
# `mkdir`, which errors when the directory already exists.
os.makedirs("trained_models", exist_ok=True)

In [17]:
# ---- Training configuration ----
gpu = 0                      # passed to common.torch_set_gpu
train_data = "WASFX"         # one letter per entry of `data_sources`
data_loader = default_dataloader
threads = 8                  # passed to threaded_loader
batch_size = 8
net = default_net            # NOTE: rebound to the network loaded from `network_path`
sampler = default_sampler
N = patch_size = 16          # fills the `N` placeholder of the loss spec
loss = default_loss          # spec string; instantiated in the create-losses cell
learning_rate = 1e-4
weight_decay = 5e-4
epochs = 10
network_path = "./models/faster2d2_WASF_N16.pt"   # pretrained checkpoint to fine-tune
# Output checkpoint *file*. torch.save cannot write to a directory, and
# common.mkdir_for(save_path) treats this as a file path.
save_path = f"./trained_models/faster2d2_{train_data}_N{patch_size}.pt"

In [18]:
# Select the GPU(s). `iscuda` is later passed to threaded_loader and gates trainer.cuda().
iscuda = common.torch_set_gpu(gpu)
# Presumably creates the directory that will hold `save_path` -- see tools/common.py.
common.mkdir_for(save_path)


Launching on GPUs 0


In [19]:
# Create data loader: the letters of `train_data` pick spec strings from `data_sources`.
# The specs are spliced into the `data` slot of the loader template and evaluated.
# NOTE(review): a NameError here (the recorded run hit 'web_images') means a dataset
# name used in the specs is not defined in this namespace -- check what
# `from datasets import *` actually provided.
db_spec = ','.join([data_sources[code] for code in train_data])
loader_cmd = data_loader.replace('`data`', db_spec)
db = eval(loader_cmd.replace('\n', ''))
print("Training image database =", db)
loader = threaded_loader(db, iscuda, threads, batch_size, shuffle=True)


NameError: name 'web_images' is not defined

In [None]:
# Display the assembled training dataset.
db

In [None]:
# Load the pretrained network from `network_path`. This rebinds `net`, replacing the
# constructor string assigned in the config cell.
net = load_network(network_path)

In [None]:
# # initialization
# pretrained = "./models/faster2d2_WASF_N16.pt"
# checkpoint = torch.load(pretrained, lambda a,b:a)
# net.load_pretrained(checkpoint['state_dict'])


In [None]:
# create losses
# The config cell sets `loss` to a spec string, and this cell rebinds `loss` to the
# instantiated loss object. The template is kept in `loss_template` so that re-running
# this cell does not call .replace on the loss object.
if isinstance(loss, str):
    loss_template = loss
loss_spec = loss_template.replace('`sampler`', sampler).replace('`N`', str(patch_size))
print("\n>> Creating loss = " + loss_spec)
loss = eval(loss_spec.replace('\n', ''))


In [None]:
# create optimizer over the trainable parameters only, then wrap everything in the trainer
trainable_params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate, weight_decay=weight_decay)

train = MyTrainer(net, loader, loss, optimizer)
if iscuda:
    train = train.cuda()


In [None]:
# Training loop #
for epoch in range(epochs):
    print(f"\n>> Starting epoch {epoch}...")
    train()

# `args` is not defined in this notebook. Store the 'net' constructor string of the
# checkpoint the network was loaded from, so load_network can rebuild the same
# architecture from checkpoint['net'].
net_spec = torch.load(network_path, map_location='cpu')['net']
# torch.save needs a file path; if save_path names a directory, save inside it.
model_fn = save_path
if os.path.isdir(model_fn):
    model_fn = os.path.join(model_fn, f"faster2d2_{train_data}_N{patch_size}.pt")
print(f"\n>> Saving model to {model_fn}")
torch.save({'net': net_spec, 'state_dict': net.state_dict()}, model_fn)

