In [None]:
import sys
# Make the bundled RIFE sources importable. Both entries go right after sys.path[0];
# the ECCV_RIFE subfolder ends up ahead of its parent directory.
sys.path[1:1] = ['/kaggle/input/eccv-rife/ECCV_RIFE', '/kaggle/input/eccv-rife/']

### Dataset

In [None]:
import os
import cv2
import torch
import numpy as np
import random
from torch.utils.data import DataLoader, Dataset
import csv
from decord import VideoReader
from decord import cpu, gpu

In [None]:
class AccessMathDataset(Dataset):
    """Frame-interpolation triplets decoded from the VI_dataset_mix_224_10s videos.

    Each item is one video. It is decoded at a fixed set of frame indices and grouped
    into (first, last, middle) triplets, where the middle frame is the interpolation
    target.

    Args:
        split_name: ``'train'`` selects the first 90% of the CSV rows marked ``train``.
            ``'test'`` selects every row not marked ``train``. Any other value selects
            the remaining 10% of the ``train`` rows (used as validation).
    """

    # Frame indices are precomputed once for a clip of NUM_FRAMES frames, cut into
    # windows of FRAME_STRIDE frames. NUM_TRIPLETS evenly spaced windows are kept.
    NUM_FRAMES = 300
    FRAME_STRIDE = 7
    NUM_TRIPLETS = 9

    def __init__(self, split_name):
        self.split_name = split_name
        self.h = 224
        self.w = 224
        self.data_root = '/kaggle/input/vi-dataset-mix-224-10s/VI_dataset_mix_224_10s'

        self.csv_path = os.path.join(self.data_root, 'splits.csv')
        self.trainlist, self.testlist = [], []
        # newline='' is how the csv module expects files to be opened.
        with open(self.csv_path, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            for row in reader:
                if not row:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                name, split = row
                if split == 'train':
                    self.trainlist.append(name)
                else:
                    self.testlist.append(name)
        self.load_data()
        self.all_imgs_idx = self.get_img_idx(
            self.NUM_FRAMES, self.FRAME_STRIDE, self.NUM_TRIPLETS)

    def get_img_idx(self, num_frames, n, num_samples=9):
        """Return flattened frame indices of evenly spaced interpolation triplets.

        ``range(0, num_frames, n)`` gives the window starts. Window i spans
        ``[start_i, start_{i+1} - 1]`` and yields the triplet
        ``(first, last, (first + last) // 2)``. Only ``num_samples`` windows, chosen
        with ``np.linspace``, are kept to bound decoding time.

        Args:
            num_frames: number of frames the indices must stay below.
            n: window stride in frames.
            num_samples: how many windows to keep (default 9).

        Returns:
            1-D int array of length ``3 * num_samples``, laid out as
            ``[first0, last0, mid0, first1, last1, mid1, ...]``.

        Raises:
            ValueError: if ``num_frames`` / ``n`` yield no complete window.
        """
        starts = np.arange(0, num_frames, n)
        if len(starts) < 2:
            raise ValueError(
                f"num_frames={num_frames} with stride n={n} yields no complete window")
        # The last start has no successor to close its window, so it is dropped.
        firsts = starts[:-1]
        lasts = starts[1:] - 1
        mids = (firsts + lasts) // 2
        triplets = np.stack([firsts, lasts, mids], axis=1)

        keep = np.linspace(0, len(triplets) - 1, num_samples, dtype=int)
        return triplets[keep].flatten()

    def __len__(self):
        return len(self.meta_data)

    def load_data(self):
        """Select this split's video names into ``self.meta_data`` (90/10 train/val cut)."""
        cnt = int(len(self.trainlist) * 0.90)
        if self.split_name == 'train':
            self.meta_data = self.trainlist[:cnt]
        elif self.split_name == 'test':
            self.meta_data = self.testlist
        else:
            self.meta_data = self.trainlist[cnt:]

    def getimg(self, index):
        """Decode the precomputed frames of video ``index``.

        Returns:
            (frames, timestep): ``frames`` is the decord batch as a numpy array.
            ``timestep`` is 0.5 because each middle frame sits exactly halfway
            between its first and last frame.

        Raises:
            ValueError: if the video is shorter than the precomputed indices require.
        """
        vid_path = os.path.join(self.data_root, self.meta_data[index])
        vr = VideoReader(vid_path, ctx=cpu(0))

        last_needed = int(self.all_imgs_idx.max())
        num_frames = len(vr)
        if num_frames <= last_needed:
            raise ValueError(
                f"{vid_path} has {num_frames} frames; frame index {last_needed} is required")

        timestep = 0.5
        frames = vr.get_batch(self.all_imgs_idx).asnumpy()
        return frames, timestep

    def __getitem__(self, index):
        """Return ``(triplets, timestep)`` for video ``index``.

        ``triplets`` has shape ``(num_triplets, 3 * C, H, W)``, dtype uint8. Each row
        concatenates the channels of first, last and middle frame, in that order.
        ``timestep`` is a float16 scalar tensor.
        """
        frames, timestep = self.getimg(index)
        # (F, H, W, C) -> (F, C, H, W) -> (F / 3, 3 * C, H, W)
        frames_t = torch.tensor(frames, dtype=torch.uint8).permute(0, 3, 1, 2)
        channels_per_triplet = 3 * frames_t.shape[1]
        triplets = frames_t.reshape(-1, channels_per_triplet, *frames_t.shape[2:])
        return triplets, torch.tensor(timestep, dtype=torch.float16)

In [None]:
# Smoke test: build every split and pull a single training batch to check dtype/shape.
t_dataset = AccessMathDataset('train')
v_dataset = AccessMathDataset('validation')
test_dataset = AccessMathDataset('test')

# pin_memory=True, num_workers=8 can be added once the pipeline is verified.
dataloader = DataLoader(dataset=t_dataset, batch_size=2, shuffle=True)

for imgs, timestep in dataloader:
    # Only the first batch is inspected.
    print(imgs.dtype)
    print(imgs.shape)
    print(timestep)
    break

# Split sizes followed by their total.
(len(t_dataset), len(v_dataset), len(test_dataset),
 len(t_dataset) + len(v_dataset) + len(test_dataset))

### Train

In [None]:
import os
import cv2
import math
import time
import torch
import numpy as np
import random

from torch.utils.data import DataLoader, Dataset

# from model.RIFE import Model
# from ECCV_RIFE.model.RIFE import Model
# from ECCV_RIFE_L1.model.RIFE import Model
# from ECCV_RIFE_BN.model.RIFE import Model
# from ECCV_RIFE_drop_2.model.RIFE import Model
from ECCV_RIFE_BN_DP_3.model.RIFE import Model


from tqdm import tqdm
import pickle 
from matplotlib import pyplot as plt 


In [None]:
# --- Training hyper-parameters -------------------------------------------------
max_epochs = 60
batch_size = 10       # videos per DataLoader batch
local_rank = 0        # distributed-training rank; not referenced by the cells shown here
world_size = 4        # train() scales the LR by world_size / 4, so 4 leaves it unchanged
num_workers = 4
pin_memory = True

# --- Reproducibility ------------------------------------------------------------
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn
# NOTE(review): cudnn.benchmark selects kernels at runtime, so runs are not bit-for-bit
# reproducible even with the seeds above.
torch.backends.cudnn.benchmark = True

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Output folders (created if missing). train() writes best-val-loss checkpoints,
# preview images and losses.pkl to train_log_path, and best-PSNR checkpoints
# to val_log_path.
train_log_path, val_log_path = 'new_train_log', 'new_val_log'
for _log_dir in (train_log_path, val_log_path):
    os.makedirs(_log_dir, exist_ok=True)
del _log_dir



# def get_learning_rate(step, step_per_epoch):
#     if step < 2000:
#         mul = step / 2000.
#         return 3e-4 * mul
#     else:
#         mul = np.cos((step - 2000) / (max_epochs *
#                      step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5
#         return (3e-4 - 3e-6) * mul + 3e-6

def get_learning_rate(step, step_per_epoch):
    """Learning rate to use at optimisation step ``step``.

    Currently a constant 5e-5 regardless of the arguments. ``step`` and
    ``step_per_epoch`` are accepted so a step-dependent schedule can be swapped in
    without touching callers.
    """
    constant_lr = 5e-5
    return constant_lr
    
def flow2rgb(flow_map_np):
    """Render an optical-flow field as an RGB image for visual inspection.

    Only channels 0 and 1 of ``flow_map_np`` (shape ``(H, W, C)``) are drawn. They are
    first normalised by the largest absolute value over the whole array.

    Args:
        flow_map_np: numpy array of shape ``(H, W, C)`` with ``C >= 2``.

    Returns:
        float32 array of shape ``(H, W, 3)`` clipped to [0, 1]. An all-zero (or empty)
        flow renders as all ones instead of producing NaNs from 0 / 0.
    """
    h, w, _ = flow_map_np.shape
    rgb_map = np.ones((h, w, 3), dtype=np.float32)

    # initial=0 keeps .max() defined for empty arrays.
    max_abs = np.abs(flow_map_np).max(initial=0.0)
    if max_abs == 0:
        normalized_flow_map = np.zeros_like(flow_map_np)
    else:
        normalized_flow_map = flow_map_np / max_abs

    rgb_map[:, :, 0] += normalized_flow_map[:, :, 0]
    rgb_map[:, :, 1] -= 0.5 * \
        (normalized_flow_map[:, :, 0] + normalized_flow_map[:, :, 1])
    rgb_map[:, :, 2] += normalized_flow_map[:, :, 1]
    return rgb_map.clip(0, 1)


def _to_uint8_hwc(batch):
    """Convert a (B, C, H, W) float tensor in [0, 1] to a (B, H, W, C) uint8 array.

    Values are clipped after scaling so slight overshoots do not wrap around in uint8.
    """
    arr = batch.detach().permute(0, 2, 3, 1).cpu().numpy()
    return (arr * 255).clip(0, 255).astype('uint8')


def _save_train_visuals(step, pred, gt, info, num_samples=2):
    """Write merged/flow/mask preview JPEGs for the first ``num_samples`` samples to train_log_path."""
    gt_np = _to_uint8_hwc(gt)
    pred_np = _to_uint8_hwc(pred)
    merged_np = _to_uint8_hwc(info['merged_tea'])
    mask_np = _to_uint8_hwc(torch.cat((info['mask'], info['mask_tea']), 3))
    flow_student = info['flow'].detach().permute(0, 2, 3, 1).cpu().numpy()
    flow_teacher = info['flow_tea'].detach().permute(0, 2, 3, 1).cpu().numpy()

    for k in range(min(num_samples, gt_np.shape[0])):
        # Images are RGB; cv2.imwrite expects BGR, hence the channel flip.
        side_by_side = np.concatenate((merged_np[k], pred_np[k], gt_np[k]), 1)[:, :, ::-1]
        cv2.imwrite(os.path.join(train_log_path, f"img{k}_merged_{step}.jpeg"), side_by_side)
        # flow2rgb returns floats in [0, 1]; JPEG output needs 8-bit data, otherwise
        # the preview comes out (nearly) black.
        flow_rgb = np.concatenate((flow2rgb(flow_student[k]), flow2rgb(flow_teacher[k])), 1)
        flow_bgr = (flow_rgb[:, :, ::-1] * 255).astype('uint8')
        cv2.imwrite(os.path.join(train_log_path, f"img{k}_flow_{step}.jpeg"), flow_bgr)
        cv2.imwrite(os.path.join(train_log_path, f"img{k}_mask_{step}.jpeg"), mask_np[k])


def train(model):
    """Fine-tune ``model`` on the train split with per-epoch validation and early stopping.

    Side effects:
        - Saves a checkpoint to train_log_path whenever validation loss improves, and to
          val_log_path whenever validation PSNR improves.
        - At epoch 50 (when reached), also saves to ``model_after_epoch50``.
        - Every 200 steps (after step 200), writes preview JPEGs to train_log_path.
        - Stops early after ``patience`` epochs whose validation loss is worse than the best.
        - Draws a loss plot and pickles the per-epoch losses to train_log_path/losses.pkl.

    NOTE(review): assumes ``model.update(imgs, gt, lr, training=True)`` returns
    ``(pred, info)``, where ``info['loss_all']`` is a scalar (float or 0-d tensor) —
    confirm against the Model implementation.
    """
    train_losses, val_losses = [], []
    min_val_loss = float('inf')  # any finite loss counts as the first improvement
    max_psnr = float('-inf')
    epochs_without_improvement, patience = 0, 7
    step = 0
    nr_eval = 0

    train_data = DataLoader(AccessMathDataset('train'), batch_size=batch_size,
                            num_workers=num_workers, pin_memory=pin_memory, shuffle=True)
    step_per_epoch = len(train_data)
    val_data = DataLoader(AccessMathDataset('validation'), batch_size=batch_size,
                          pin_memory=pin_memory, num_workers=num_workers)
    print('training...')
    for epoch in range(max_epochs):
        pbar = tqdm(train_data, total=len(train_data), desc=f"Epoch{epoch}, Step", position=0)
        epoch_losses = []
        for data_gpu, _timestep in pbar:
            data_gpu = data_gpu.to(device) / 255.
            # (batch, triplets, 9, H, W) -> (batch * triplets, 9, H, W): each triplet is a sample.
            data_gpu = data_gpu.reshape(-1, *data_gpu.shape[2:])

            imgs = data_gpu[:, :6]   # first + last frame, RGB each
            gt = data_gpu[:, 6:9]    # middle frame to predict
            learning_rate = get_learning_rate(step, step_per_epoch) * world_size / 4
            # TODO: pass timestep if you are training RIFEm
            pred, info = model.update(imgs, gt, learning_rate, training=True)

            if step > 200 and step % 200 == 0:
                _save_train_visuals(step, pred, gt, info)

            pbar.set_postfix({'lr': learning_rate,
                              'l_l1': info['loss_l1'].detach().item(),
                              'l_tea': info['loss_tea'].detach().item(),
                              'l_dist': info['loss_distill'].detach().item()})
            step += 1
            # float() yields a plain number, so the list cannot keep autograd graphs or
            # GPU memory alive for the whole epoch.
            epoch_losses.append(float(info['loss_all']))

        train_losses.append(np.mean(epoch_losses))

        nr_eval += 1
        c_psnr, val_loss = evaluate(model, val_data, step)
        val_losses.append(val_loss)
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            model.save_model(train_log_path)
            epochs_without_improvement = 0
        elif val_loss > min_val_loss:
            epochs_without_improvement += 1
            if epochs_without_improvement == patience:
                print("Early stopping at epoch", epoch)
                break
        if max_psnr < c_psnr:
            max_psnr = c_psnr
            model.save_model(val_log_path)

        if epoch >= 50 and epoch % 50 == 0:
            model_save_path = f"model_after_epoch{epoch}"
            os.makedirs(model_save_path, exist_ok=True)
            model.save_model(model_save_path)

    train_losses, val_losses = np.array(train_losses), np.array(val_losses)

    fig, ax = plt.subplots()
    ax.plot(train_losses, label='Train')
    ax.plot(val_losses, label='Val')
    ax.set(xlabel='Epochs', ylabel='Loss', title='Loss per epoch')
    ax.legend()

    with open(f'{train_log_path}/losses.pkl', 'wb') as f:
        pickle.dump({'train_losses': train_losses, 'val_losses': val_losses}, f)

def _psnr_per_sample(pred, gt, eps=1e-10):
    """Per-sample PSNR in dB for (N, C, H, W) tensors with values in [0, 1].

    MSE is floored at ``eps`` (caps PSNR at 100 dB). Identical images therefore give
    a finite value instead of the ValueError raised by ``math.log10(0)``.
    """
    with torch.no_grad():
        mse = ((gt - pred) ** 2).mean(dim=(1, 2, 3))
        # double() keeps the log in float64, like the original per-sample math.log10.
        return (-10 * torch.log10(mse.clamp_min(eps).double())).tolist()


def evaluate(model, val_data, nr_eval, is_tqdm=False):
    """Run ``model`` over ``val_data`` without gradients and report PSNR.

    Args:
        model: object with ``update(imgs, gt, training=False) -> (pred, info)``. ``info``
            must provide ``'merged_tea'`` and a scalar ``'loss_all'``
            (TODO confirm against Model).
        val_data: iterable of ``(frames, timestep)`` batches shaped like
            AccessMathDataset items.
        nr_eval: label printed with the results (train() passes its global step).
        is_tqdm: wrap the loop in a tqdm progress bar.

    Returns:
        ``(mean student PSNR, mean loss_all)``. Both are nan if ``val_data`` is empty.
    """
    val_loss_all = []
    psnr_list = []
    psnr_list_teacher = []
    batches = tqdm(val_data, total=len(val_data)) if is_tqdm else val_data
    for data_gpu, _timestep in batches:
        data_gpu = data_gpu.to(device) / 255.
        # (batch, triplets, 9, H, W) -> (batch * triplets, 9, H, W)
        data_gpu = data_gpu.reshape(-1, *data_gpu.shape[2:])

        imgs = data_gpu[:, :6]
        gt = data_gpu[:, 6:9]
        with torch.no_grad():
            pred, info = model.update(imgs, gt, training=False)
        merged_img = info['merged_tea']

        val_loss_all.append(float(info['loss_all']))
        psnr_list.extend(_psnr_per_sample(pred, gt))
        psnr_list_teacher.extend(_psnr_per_sample(merged_img, gt))

    mean_psnr = np.mean(psnr_list)
    print(f"nr_eval => {nr_eval} | psnr={mean_psnr} | psnr_teacher={np.mean(psnr_list_teacher)}")
    return mean_psnr, np.mean(val_loss_all)

In [None]:


# Fine-tune: start from the pretrained RIFE checkpoint directory, then run the training loop.
pretrained_model = '/kaggle/input/rife-trained-v6'
model = Model()
model.load_model(pretrained_model, -1)

train(model)

### Test

In [None]:
# Evaluate the best-validation-loss checkpoint written by train() on the test split.
model_path = "/kaggle/working/new_train_log/flownet.pkl"
model = Model()
# map_location lets a checkpoint saved from a GPU session load on a CPU-only one.
# load_state_dict then copies the weights onto the device flownet already lives on.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
state_dict = torch.load(model_path, map_location=device)
model.flownet.load_state_dict(state_dict)

dataset_test = AccessMathDataset('test')
print(len(dataset_test))
test_data = DataLoader(dataset_test, batch_size=batch_size,
                       pin_memory=pin_memory, num_workers=num_workers)
evaluate(model, test_data, 0)