In [1]:
!cp -r /kaggle/input/mprnetv2/MPRNet /kaggle/working/

In [2]:
%cd /kaggle/working/MPRNet/Deraining/

/kaggle/working/MPRNet/Deraining


In [3]:
!pip install yacs warmup_scheduler

Collecting yacs
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting warmup_scheduler
  Downloading warmup_scheduler-0.3.tar.gz (2.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Building wheels for collected packages: warmup_scheduler
  Building wheel for warmup_scheduler (setup.py) ... [?25l[?25hdone
  Created wheel for warmup_scheduler: filename=warmup_scheduler-0.3-py3-none-any.whl size=2971 sha256=ec6ace5a61ac6a3d482b387318c15abf32f1aaa45071c62c98c254d3f940daf0
  Stored in directory: /root/.cache/pip/wheels/cc/5c/3b/6e5033100e0e4191383dad5c4279638a37f9791d1af9e1d85c
Successfully built warmup_scheduler
Installing collected packages: warmup_scheduler, yacs
Successfully installed warmup_scheduler-0.3 yacs-0.1.8


In [4]:
!mkdir -p /kaggle/working/MPRNet/Deraining/checkpoints/Deraining/models/MPRNet

In [5]:
!cp -r /kaggle/input/mprnet-checkpoints-rain14000/model_epoch_20.pth /kaggle/working/MPRNet/Deraining/checkpoints/Deraining/models/MPRNet

In [6]:
import yaml

yaml_path = '/kaggle/working/MPRNet/Deraining/training.yml'

# load file
with open(yaml_path, 'r') as f:
    cfg = yaml.safe_load(f)

# chỉnh sửa config
cfg['GPU'] = [0]
cfg['OPTIM']['BATCH_SIZE'] = 4
cfg['OPTIM']['NUM_EPOCHS'] = 100
cfg['OPTIM']['LR_INITIAL'] = 2e-4
cfg['OPTIM']['LR_MIN'] = 1e-6

cfg['TRAINING']['TRAIN_DIR'] = '/kaggle/input/rain14000/train'
cfg['TRAINING']['VAL_DIR'] = '/kaggle/input/rain14000/val'
cfg['TRAINING']['SAVE_DIR'] = '/kaggle/working/checkpoints'

cfg['TRAINING']['RESUME'] = True

# save lại file
with open(yaml_path, 'w') as f:
    yaml.dump(cfg, f)

print("Updated training.yml")


Updated training.yml


In [7]:
%%writefile dataset_RGB.py
# ghi đè toàn bộ nội dung mới
import os
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])

class DataLoaderTrain(Dataset):
    def __init__(self, rgb_dir, img_options=None):
        super(DataLoaderTrain, self).__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'output')))

        self.inp_filenames = [os.path.join(rgb_dir, 'input', x)  for x in inp_files if is_image_file(x)]
        self.tar_filenames = [os.path.join(rgb_dir, 'output', x) for x in tar_files if is_image_file(x)]

        self.img_options = img_options
        self.sizex       = len(self.tar_filenames)  # get the size of target

        self.ps = self.img_options['patch_size']

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        w,h = tar_img.size
        padw = ps-w if w<ps else 0
        padh = ps-h if h<ps else 0

        # Reflect Pad in case image is smaller than patch_size
        if padw!=0 or padh!=0:
            inp_img = TF.pad(inp_img, (0,0,padw,padh), padding_mode='reflect')
            tar_img = TF.pad(tar_img, (0,0,padw,padh), padding_mode='reflect')

        inp_img = TF.to_tensor(inp_img)
        tar_img = TF.to_tensor(tar_img)

        hh, ww = tar_img.shape[1], tar_img.shape[2]

        rr     = random.randint(0, hh-ps)
        cc     = random.randint(0, ww-ps)
        aug    = random.randint(0, 8)

        # Crop patch
        inp_img = inp_img[:, rr:rr+ps, cc:cc+ps]
        tar_img = tar_img[:, rr:rr+ps, cc:cc+ps]

        # Data Augmentations
        if aug==1:
            inp_img = inp_img.flip(1)
            tar_img = tar_img.flip(1)
        elif aug==2:
            inp_img = inp_img.flip(2)
            tar_img = tar_img.flip(2)
        elif aug==3:
            inp_img = torch.rot90(inp_img,dims=(1,2))
            tar_img = torch.rot90(tar_img,dims=(1,2))
        elif aug==4:
            inp_img = torch.rot90(inp_img,dims=(1,2), k=2)
            tar_img = torch.rot90(tar_img,dims=(1,2), k=2)
        elif aug==5:
            inp_img = torch.rot90(inp_img,dims=(1,2), k=3)
            tar_img = torch.rot90(tar_img,dims=(1,2), k=3)
        elif aug==6:
            inp_img = torch.rot90(inp_img.flip(1),dims=(1,2))
            tar_img = torch.rot90(tar_img.flip(1),dims=(1,2))
        elif aug==7:
            inp_img = torch.rot90(inp_img.flip(2),dims=(1,2))
            tar_img = torch.rot90(tar_img.flip(2),dims=(1,2))
        
        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

        return tar_img, inp_img, filename

class DataLoaderVal(Dataset):
    def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
        super(DataLoaderVal, self).__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'output')))

        self.inp_filenames = [os.path.join(rgb_dir, 'input', x)  for x in inp_files if is_image_file(x)]
        self.tar_filenames = [os.path.join(rgb_dir, 'output', x) for x in tar_files if is_image_file(x)]

        self.img_options = img_options
        self.sizex       = len(self.tar_filenames)  # get the size of target

        self.ps = self.img_options['patch_size']

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Validate on center crop
        if self.ps is not None:
            inp_img = TF.center_crop(inp_img, (ps,ps))
            tar_img = TF.center_crop(tar_img, (ps,ps))

        inp_img = TF.to_tensor(inp_img)
        tar_img = TF.to_tensor(tar_img)

        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

        return tar_img, inp_img, filename

class DataLoaderTest(Dataset):
    def __init__(self, inp_dir, img_options):
        super(DataLoaderTest, self).__init__()

        inp_files = sorted(os.listdir(inp_dir))
        self.inp_filenames = [os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)]

        self.inp_size = len(self.inp_filenames)
        self.img_options = img_options

    def __len__(self):
        return self.inp_size

    def __getitem__(self, index):

        path_inp = self.inp_filenames[index]
        filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
        inp = Image.open(path_inp)

        inp = TF.to_tensor(inp)
        return inp, filename


Overwriting dataset_RGB.py


In [8]:
%%writefile train.py

import os
from config import Config 
opt = Config('training.yml')

gpus = ','.join([str(i) for i in opt.GPU])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus

import torch
torch.backends.cudnn.benchmark = False

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import random
import time
import numpy as np

import utils
from data_RGB import get_training_data, get_validation_data
from MPRNet import MPRNet
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from pdb import set_trace as stx

######### Set Seeds ###########
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)

start_epoch = 1
mode = opt.MODEL.MODE
session = opt.MODEL.SESSION

result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
model_dir  = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models',  session)

utils.mkdir(result_dir)
utils.mkdir(model_dir)

train_dir = opt.TRAINING.TRAIN_DIR
val_dir   = opt.TRAINING.VAL_DIR

######### Model ###########
model_restoration = MPRNet()
model_restoration.cuda()

device_ids = [i for i in range(torch.cuda.device_count())]
if torch.cuda.device_count() > 1:
  print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")


new_lr = opt.OPTIM.LR_INITIAL

optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8)


######### Scheduler ###########
warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
# scheduler.step()

######### Resume ###########
# if opt.TRAINING.RESUME:
#     path_chk_rest    = utils.get_last_path(model_dir, 'model_epoch_*.pth')
#     utils.load_checkpoint(model_restoration,path_chk_rest)
#     start_epoch = utils.load_start_epoch(path_chk_rest) + 1
#     utils.load_optim(optimizer, path_chk_rest)

#     for i in range(1, start_epoch):
#         scheduler.step()
#     new_lr = scheduler.get_lr()[0]
#     print('------------------------------------------------------------------------------')
#     print("==> Resuming Training with learning rate:", new_lr)
#     print('------------------------------------------------------------------------------')

if opt.TRAINING.RESUME:
    path_chk_rest = '/kaggle/input/mprnet-checkpoints-rain14000/model_epoch_80.pth'

    if path_chk_rest is not None:
        utils.load_checkpoint(model_restoration, path_chk_rest)
        start_epoch = utils.load_start_epoch(path_chk_rest) + 1
        utils.load_optim(optimizer, path_chk_rest)

        # step scheduler đúng số epoch đã train
        for _ in range(1, start_epoch):
            scheduler.step()

        new_lr = scheduler.get_last_lr()[0]
        print('------------------------------------------------------------------------------')
        print("==> Resuming Training from epoch {} with LR {:.8f}".format(start_epoch, new_lr))
        print('------------------------------------------------------------------------------')
    else:
        print("⚠️ No checkpoint found. Training from scratch.")
        start_epoch = 1

if len(device_ids)>1:
    model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)

######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()

######### DataLoaders ###########
train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)

val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS})
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)

print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1))
print('===> Loading datasets')

best_psnr = 31.0807
best_epoch = 0

for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
    epoch_start_time = time.time()
    epoch_loss = 0
    train_id = 1

    model_restoration.train()
    for i, data in enumerate(tqdm(train_loader), 0):

        # zero_grad
        for param in model_restoration.parameters():
            param.grad = None

        target = data[0].cuda()
        input_ = data[1].cuda()

        restored = model_restoration(input_)
 
        # Compute loss at each stage
        loss_char = sum([criterion_char(restored[j],target) for j in range(len(restored))])
        loss_edge = sum([criterion_edge(restored[j],target) for j in range(len(restored))])
        loss = (loss_char) + (0.05*loss_edge)
        
        loss.backward()
        optimizer.step()
        epoch_loss +=loss.item()

    #### Evaluation ####
    if epoch%opt.TRAINING.VAL_AFTER_EVERY == 0:
        model_restoration.eval()
        psnr_val_rgb = []
        for ii, data_val in enumerate((val_loader), 0):
            target = data_val[0].cuda()
            input_ = data_val[1].cuda()

            with torch.no_grad():
                restored = model_restoration(input_)
            restored = restored[0]

            for res,tar in zip(restored,target):
                psnr_val_rgb.append(utils.torchPSNR(res, tar))

        psnr_val_rgb  = torch.stack(psnr_val_rgb).mean().item()

        if psnr_val_rgb > best_psnr:
            best_psnr = psnr_val_rgb
            best_epoch = epoch
            torch.save({'epoch': epoch, 
                        'state_dict': model_restoration.state_dict(),
                        'optimizer' : optimizer.state_dict()
                        }, os.path.join(model_dir,"model_best.pth"))

        print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))

        torch.save({'epoch': epoch, 
                    'state_dict': model_restoration.state_dict(),
                    'optimizer' : optimizer.state_dict()
                    }, os.path.join(model_dir,f"model_epoch_{epoch}.pth")) 

    scheduler.step()
    
    print("------------------------------------------------------------------")
    print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.8f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
    print("------------------------------------------------------------------")

    torch.save({'epoch': epoch, 
                'state_dict': model_restoration.state_dict(),
                'optimizer' : optimizer.state_dict()
                }, os.path.join(model_dir,"model_latest.pth")) 



Overwriting train.py


In [9]:
!python train.py

  _warn_get_lr_called_within_step(self)
------------------------------------------------------------------------------
==> Resuming Training from epoch 81 with LR 0.00020000
------------------------------------------------------------------------------
===> Start Epoch 81 End Epoch 101
===> Loading datasets
100%|███████████████████████████████████████| 2742/2742 [35:11<00:00,  1.30it/s]
------------------------------------------------------------------
Epoch: 81	Time: 2111.2725	Loss: 183.2656	LearningRate 0.00001935
------------------------------------------------------------------
100%|███████████████████████████████████████| 2742/2742 [35:14<00:00,  1.30it/s]
------------------------------------------------------------------
Epoch: 82	Time: 2114.3209	Loss: 183.0762	LearningRate 0.00001753
------------------------------------------------------------------
100%|███████████████████████████████████████| 2742/2742 [35:15<00:00,  1.30it/s]
-----------------------------------