# Import

In [None]:
import torch
import torchvision
import cv2 
import numpy as np
from pathlib import Path
import sys
from pathlib import Path
from IPython.display import clear_output
import matplotlib.pyplot as plt
import torchvision.models as models

# Put the parent of the notebook's working directory on the import path so
# that the `src` package imported right after this can be resolved.
WORK_DIR = Path.cwd().parent
sys.path += [str(WORK_DIR)]
from src import ROOT
from src.datasets.transforms import *
from src.utils import *

# Setup

In [None]:
# Experiment configuration shared by the cells below and by Train_Dataset.
cfg = {
    'exp': 'exp1',  # experiment tag, formatted into checkpoint / GIF file names
    'img_root': Path(ROOT)/'EPIC_KITCHENS_2018'/'EK_frames',  # directory the frame paths are joined onto
    'img_tmpl': 'img_{:05d}.jpg',  # frame file-name template, formatted with the frame index
    'img_rsz': 480,  # every frame is resized to img_rsz x img_rsz
    'crop_sz': 80,  # side of the square crop; also used as the zero-padding width
    'seq_idx': 2,  # line index into the sequence label file
    'chosen_pt': (200, 280),  # target point, read as (row, col)
    'len_data': 128,  # value returned by Train_Dataset.__len__
    'jitter': 20,  # bound of the random offset applied to samples drawn around chosen_pt
    'batch_size': 32,  # DataLoader batch size
    'num_workers': 8,  # DataLoader worker count
    'show_pos_prob': 0.5,  # Train_Dataset samples around chosen_pt when a uniform draw is <= this value
}

In [None]:
# Bind the frequently used settings from cfg to notebook-level names.
(img_root, img_tmpl, img_rsz, crop_sz, seq_idx, chosen_pt, exp) = (
    cfg[key] for key in ('img_root', 'img_tmpl', 'img_rsz', 'crop_sz',
                         'seq_idx', 'chosen_pt', 'exp')
)

In [None]:
# Pick the label line selected by cfg['seq_idx'].  The line is split on single
# spaces: field 0 is the sequence directory, field 1 the number of frames.
label_file = Path(ROOT)/'mlcv-exp'/'data'/'labels'/'ek_ar_seq_train.txt'
with open(label_file) as f:
    seq_line = f.read().splitlines()[cfg['seq_idx']]

seq_fields = seq_line.split(' ')
img_path = seq_fields[0]
path_length = int(seq_fields[1])

# One frame path per index 0 .. path_length - 1, already joined with img_root.
img_list = []
for i in range(path_length):
    img_list += [img_root/img_path/img_tmpl.format(i)]

# Choose First Point

In [None]:
# Load the first frame.  img_list entries are already joined with img_root, so
# they are used directly.  cv2.imread returns None for an unreadable file, so
# fail with a clear error instead of a TypeError on the channel flip.
img0 = cv2.imread(str(img_list[0]))
if img0 is None:
    raise FileNotFoundError('Could not read frame: {}'.format(img_list[0]))
# Reverse the channel axis (cv2 loads BGR) and resize to a square frame.
img0 = cv2.resize(img0[:, :, ::-1], (img_rsz, img_rsz))

# Padding width used later by the padded-crop cell.
pad = crop_sz

# chosen_pt is (row, col); the box is stored as [x, y, w, h].
# NOTE(review): x/y are presumably the box centre (per the *_cen naming) --
# confirm against xywh2xyxy.
x_cen = chosen_pt[1]
y_cen = chosen_pt[0]
bbox_w = crop_sz
bbox_h = crop_sz
chosen_point_bbox = np.asarray([x_cen, y_cen, bbox_w, bbox_h])

# Convert to corner form and order the corners so the slice below is valid
# regardless of which corner comes first.
chosen_point_bbox_xyxy = xywh2xyxy(chosen_point_bbox)
x1 = np.minimum(chosen_point_bbox_xyxy[0], chosen_point_bbox_xyxy[2])
y1 = np.minimum(chosen_point_bbox_xyxy[1], chosen_point_bbox_xyxy[3])
x2 = np.maximum(chosen_point_bbox_xyxy[0], chosen_point_bbox_xyxy[2])
y2 = np.maximum(chosen_point_bbox_xyxy[1], chosen_point_bbox_xyxy[3])

chosen_crop = img0[y1:y2, x1:x2]

fig, ax = plt.subplots()
ax.imshow(chosen_crop)
plt.show()

# Random crop

In [None]:
# Draw a point near chosen_pt: each axis gets its own integer offset from
# [-10, 10).  The row offset is drawn first, then the column offset.
rand_row = np.random.randint(-10, 10) + chosen_pt[0]
rand_col = np.random.randint(-10, 10) + chosen_pt[1]
x_cen, y_cen = rand_col, rand_row
bbox_w = bbox_h = crop_sz
random_choice_bbox = np.asarray([x_cen, y_cen, bbox_w, bbox_h])

# Overlay both boxes on the frame: chosen box in blue, sampled box in red.
fig, ax = plt.subplots()
ax.imshow(img0)
for box, colour in ((chosen_point_bbox, 'b'), (random_choice_bbox, 'r')):
    draw_bbox(ax, box, colour)
plt.show()

# Overlap between the sampled box and the chosen box, as reported by bbox_iou.
print(bbox_iou(random_choice_bbox, chosen_point_bbox))

In [None]:
# Zero-pad the frame by `pad` pixels on each spatial side; the channel axis is
# left unpadded.  A single np.pad call with per-axis widths replaces the
# per-channel pad followed by two axis swaps.
img0_pad = np.pad(img0, pad_width=((pad, pad), (pad, pad), (0, 0)),
                  mode='constant', constant_values=0)

# Shift the x/y entries of both boxes into padded-image coordinates; the
# width/height entries are unchanged.
random_choice_bbox_pad = random_choice_bbox.copy()
random_choice_bbox_pad[:2] += pad

chosen_bbox_pad = chosen_point_bbox.copy()
chosen_bbox_pad[:2] += pad

# Corner coordinates of the sampled box, ordered and shifted by the padding.
random_choice_bbox_xyxy = xywh2xyxy(random_choice_bbox)
x1 = np.minimum(random_choice_bbox_xyxy[0], random_choice_bbox_xyxy[2]) + pad
y1 = np.minimum(random_choice_bbox_xyxy[1], random_choice_bbox_xyxy[3]) + pad
x2 = np.maximum(random_choice_bbox_xyxy[0], random_choice_bbox_xyxy[2]) + pad
y2 = np.maximum(random_choice_bbox_xyxy[1], random_choice_bbox_xyxy[3]) + pad

# Green overlay marking the sampled region, plus the crop itself.
color_mask = np.zeros((img0_pad.shape[0], img0_pad.shape[1], 3))
color_mask[y1:y2, x1:x2, 1] = 1
rand_crop = img0_pad[y1:y2, x1:x2]

fig, ax = plt.subplots()
ax.imshow(img0_pad)
ax.imshow(color_mask, alpha=0.5)
draw_bbox(ax, chosen_bbox_pad, 'b')
draw_bbox(ax, random_choice_bbox_pad, 'r')
plt.show()

fig, ax = plt.subplots()
ax.imshow(rand_crop)
plt.show()

# Dataloader

In [None]:
import time

class Train_Dataset(torch.utils.data.Dataset):
    """Random square crops of a single frame, labelled by overlap with a target box.

    Each item is a ``crop_sz`` x ``crop_sz`` crop of the zero-padded frame
    together with the value ``bbox_iou`` returns for the crop's box and the
    box built around ``cfg['chosen_pt']``.

    Args:
        cfg: mapping providing ``img_rsz``, ``crop_sz``, ``chosen_pt``
            (``(row, col)``), ``show_pos_prob``, ``len_data`` and ``jitter``.
        img_path: path of the frame to read, as a string.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``img_path``.
    """

    def __init__(self, cfg, img_path):
        super().__init__()
        rsz = cfg['img_rsz']
        self.crop_size = cfg['crop_sz']
        self.chosen_point = cfg['chosen_pt']
        self.rand = cfg['show_pos_prob']
        self.len_data = cfg['len_data']
        self.jitter = cfg['jitter']

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError('Could not read image: {}'.format(img_path))
        # Reverse the channel axis (cv2 loads BGR) and resize to a square.
        img = cv2.resize(img[:, :, ::-1], (rsz, rsz))
        self.img_shape = img.shape[0]

        # Zero-pad both spatial axes by crop_size; channels are not padded.
        pad = self.crop_size
        self.img_pad = np.pad(img, ((pad, pad), (pad, pad), (0, 0)),
                              mode='constant', constant_values=0)

        # chosen_pt is (row, col); boxes are stored as [x, y, w, h].
        self.chosen_point_bbox = np.asarray([self.chosen_point[1],
                                             self.chosen_point[0],
                                             self.crop_size,
                                             self.crop_size])

        self.transform = torchvision.transforms.Compose([ImgToTorch()])

        # Created lazily in whichever process first calls __getitem__.
        self._rng = None

    def _get_rng(self):
        """Return this process's private random generator, creating it on first use."""
        # A RandomState built without a seed draws fresh entropy, so each
        # DataLoader worker gets its own stream.  This replaces reseeding the
        # global NumPy RNG from a one-second clock plus the item index, which
        # produced repeated draws.
        if getattr(self, '_rng', None) is None:
            self._rng = np.random.RandomState()
        return self._rng

    def __getitem__(self, index):
        """Return ``(crop, iou)`` for one randomly placed box; ``index`` is not used."""
        rng = self._get_rng()
        if rng.uniform() > self.rand:
            # Uniformly random location anywhere in the (unpadded) frame.
            rand_row = rng.randint(0, self.img_shape)
            rand_col = rng.randint(0, self.img_shape)
        elif self.jitter:
            # Location around the chosen point; row and column are jittered
            # independently so offsets are not restricted to the diagonal.
            rand_row = self.chosen_point[0] + rng.randint(-self.jitter, self.jitter)
            rand_col = self.chosen_point[1] + rng.randint(-self.jitter, self.jitter)
        else:
            # jitter == 0: sample exactly at the chosen point.
            rand_row = self.chosen_point[0]
            rand_col = self.chosen_point[1]

        random_choice_bbox = np.asarray([rand_col, rand_row,
                                         self.crop_size, self.crop_size])
        xyxy = xywh2xyxy(random_choice_bbox)
        # Order the corners and shift them into padded-image coordinates.
        x1 = np.minimum(xyxy[0], xyxy[2]) + self.crop_size
        y1 = np.minimum(xyxy[1], xyxy[3]) + self.crop_size
        x2 = np.maximum(xyxy[0], xyxy[2]) + self.crop_size
        y2 = np.maximum(xyxy[1], xyxy[3]) + self.crop_size
        crop = self.img_pad[y1:y2, x1:x2]

        # The label is computed in unpadded coordinates.
        iou = bbox_iou(random_choice_bbox, self.chosen_point_bbox)

        sample = self.transform({'img': crop})
        return sample['img'], iou

    def __len__(self):
        """Return the configured number of samples per epoch (``len_data``)."""
        return self.len_data

In [None]:
# DataLoader settings taken from cfg: shuffling on, no custom sampler,
# pinned memory.
kwargs = dict(
    batch_size=cfg['batch_size'],
    shuffle=True,
    num_workers=cfg['num_workers'],
    sampler=None,
    pin_memory=True,
)

# The training set is built from the first frame of the sequence only.
first_frame_dataset = Train_Dataset(cfg, str(img_root/img_list[0]))
train_dataloader = torch.utils.data.DataLoader(first_frame_dataset, **kwargs)

# Model

In [None]:
# Pretrained ResNet-50 whose leading stages are reused as a feature extractor.
model = models.resnet50(pretrained=True)

# Attribute names of the ResNet stages, in forward order.
resnet_feature_layers = ['conv1', 'bn1', 'relu', 'maxpool',
                         'layer1', 'layer2', 'layer3', 'layer4']

# Keep every stage up to and including `last_layer`.
last_layer = 'layer3'
last_layer_idx = resnet_feature_layers.index(last_layer)
resnet_module_list = [getattr(model, layer_name)
                      for layer_name in resnet_feature_layers]
kept_modules = resnet_module_list[:last_layer_idx + 1]
resnet_model = torch.nn.Sequential(*kept_modules)

class Full_Model(torch.nn.Module):
    """Feature extractor followed by a three-layer MLP producing one scalar per sample.

    In this notebook the scalar is trained with an MSE loss against IoU targets.

    Args:
        backbone: feature-extractor module.  When ``None`` (the default) the
            notebook-level ``resnet_model`` is used, as before.
        feat_dim: number of features per sample once the backbone output is
            flattened.  The default, 25600, is the previously hard-coded size.
    """

    def __init__(self, backbone=None, feat_dim=25600):
        super().__init__()
        # Attribute names are unchanged so existing state dicts still load.
        self.resnet_model = resnet_model if backbone is None else backbone
        self.feat_dim = feat_dim

        self.lin_out = torch.nn.Sequential(
            torch.nn.Linear(feat_dim, 512),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(512, 256),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(256, 1))

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of predictions for the batch ``x``."""
        out = self.resnet_model(x)
        # Flatten everything but the batch axis; view() raises if the backbone
        # output does not hold exactly feat_dim values per sample.
        out = out.view(out.shape[0], self.feat_dim)
        return self.lin_out(out)
    
# Build the full network and move it to the CUDA device.
model = Full_Model().cuda()

# Optimizer

In [None]:
# Adam over every parameter of the model (backbone included), lr = 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Train First Frame

In [None]:
from livelossplot import PlotLosses

liveloss = PlotLosses()
bce = torch.nn.BCEWithLogitsLoss()  # not used by this loop; kept so the name stays defined
mse = torch.nn.MSELoss()
max_epoch = 30
model = model.train()

# Start from +inf so the first epoch always writes a checkpoint; a fixed
# starting value could leave no checkpoint at all when losses stay above it.
best_loss = float('inf')

# Checkpoint destination; make sure its directory exists before training so a
# missing folder does not abort the run at the first save.
ckpt_path = Path(ROOT)/'mlcv-exp'/'data'/'weights'/'model_apt_{}.state'.format(exp)
ckpt_path.parent.mkdir(parents=True, exist_ok=True)

for epoch in range(max_epoch):
    logs = {}
    running_loss = 0.0
    for step, (img, iou) in enumerate(train_dataloader):
        img = img.cuda()
        # Cast the targets to float32 before moving them to the GPU.
        iou = iou.type(torch.FloatTensor).cuda()

        out = model(img)

        # The model emits one value per sample in column 0.
        loss = mse(out[:, 0], iou)

        running_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean batch loss for the epoch.
    running_loss /= len(train_dataloader)
    logs['loss'] = running_loss

    # Keep only the checkpoint with the lowest epoch loss seen so far.
    if running_loss < best_loss:
        best_loss = running_loss
        state = {'epoch': epoch,
                 'model_state_dict': model.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict()}
        torch.save(state, ckpt_path)
    liveloss.update(logs)
    liveloss.draw()

# Load

In [None]:
# Restore the checkpoint saved for this experiment into the model and the
# optimizer, mapping stored tensors onto the first CUDA device.
map_loc = 'cuda:0'
load_dir = Path(ROOT)/'mlcv-exp'/'data'/'weights'/'model_apt_{}.state'.format(exp)
ckpt = torch.load(load_dir, map_location=map_loc)
for target, key in ((model, 'model_state_dict'),
                    (optimizer, 'optimizer_state_dict')):
    target.load_state_dict(ckpt[key])

# Detect One

In [None]:
# Single-step preprocessing pipeline (ImgToTorch only), as in Train_Dataset.
tfrm = [ImgToTorch()]
transform = torchvision.transforms.Compose(tfrm)

# Stride, in pixels, of the sliding-window crop grid used for detection.
steps = 20

In [None]:
# Load the second frame.  img_list entries are already joined with img_root;
# cv2.imread returns None on failure, so check before using the result.
img0 = cv2.imread(str(img_list[1]))
if img0 is None:
    raise FileNotFoundError('Could not read frame: {}'.format(img_list[1]))
# Reverse the channel axis (cv2 loads BGR) and resize to a square frame.
img0 = cv2.resize(img0[:, :, ::-1], (img_rsz, img_rsz))

fig, ax = plt.subplots()
ax.imshow(img0)
plt.show()

# Top-left corners of every crop_sz x crop_sz window that fits entirely inside
# the frame, in row-major order.  Rows are bounded by the height (axis 0) and
# columns by the width (axis 1).
row_starts = range(0, img0.shape[0] - crop_sz + 1, steps)
col_starts = range(0, img0.shape[1] - crop_sz + 1, steps)
pos_list = [(row, col) for row in row_starts for col in col_starts]
img_crops = [img0[row:row+crop_sz, col:col+crop_sz] for row, col in pos_list]

# Size the display grid by the number of windows actually produced, so each
# tile is drawn at its own grid position and no empty axes are left over.
crop_rows = len(row_starts)
crop_cols = len(col_starts)
fig, ax = plt.subplots(crop_rows, crop_cols, squeeze=False)
for cell, crop in zip(ax.ravel(), img_crops):
    cell.imshow(crop)
    cell.axis('off')
plt.show()

In [None]:
# Score every crop with the model in evaluation mode, without gradients.
model = model.eval()
pred_iou = []
with torch.no_grad():
    for crop in img_crops:
        # Copy first: the crop is a view into the frame.
        sample = transform({'img': crop.copy()})
        img = sample['img'].unsqueeze(0).cuda()  # add the batch axis
        out = model(img)
        # Store a plain Python float: NumPy cannot convert a list of CUDA
        # tensors, which breaks the np.argmax over pred_iou.
        pred_iou.append(out[0, 0].item())

In [None]:
# Index of the highest-scoring crop and its (row, col) top-left corner.
best_idx = np.argmax(pred_iou)
print(best_idx)

pos = pos_list[best_idx]
top, left = pos[0], pos[1]

# Green overlay covering the winning crop.
color_mask = np.zeros((img0.shape[0], img0.shape[1], 3))
color_mask[top:top+crop_sz, left:left+crop_sz, 1] = 1

fig, ax = plt.subplots()
for layer, opts in ((img0, {}), (color_mask, {'alpha': 0.5})):
    ax.imshow(layer, **opts)
plt.show()

# Detect Video

In [None]:
from moviepy.editor import ImageSequenceClip
from tqdm import tqdm
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from IPython.display import Image as IPythonImage

fps = 12  # frame rate passed to ImageSequenceClip and write_gif below

In [None]:
frames = []

# The model is never switched back to training mode in this loop, so set
# evaluation mode once up front.
model = model.eval()

# Every frame is resized to img_rsz x img_rsz, so the grid of window corners
# (top-left, row-major) is the same for all frames and is computed once.
pos_list = [(row, col)
            for row in range(0, img_rsz - crop_sz + 1, steps)
            for col in range(0, img_rsz - crop_sz + 1, steps)]

for frame in tqdm(img_list):
    # img_list entries are already joined with img_root; cv2.imread returns
    # None on failure, so check before using the result.
    img0 = cv2.imread(str(frame))
    if img0 is None:
        raise FileNotFoundError('Could not read frame: {}'.format(frame))
    # Reverse the channel axis (cv2 loads BGR) and resize to a square frame.
    img0 = cv2.resize(img0[:, :, ::-1], (img_rsz, img_rsz))

    img_crops = [img0[row:row+crop_sz, col:col+crop_sz] for row, col in pos_list]

    pred_iou = []
    with torch.no_grad():
        for crop in img_crops:
            sample = transform({'img': crop.copy()})
            img = sample['img'].unsqueeze(0).cuda()  # add the batch axis
            # Plain floats, so np.argmax below never sees CUDA tensors.
            pred_iou.append(model(img)[0, 0].item())

    # Top-left corner of the highest-scoring window.
    pos = pos_list[np.argmax(pred_iou)]

    # Green overlay covering that window.
    color_mask = np.zeros((img0.shape[0], img0.shape[1], 3))
    color_mask[pos[0]:pos[0]+crop_sz, pos[1]:pos[1]+crop_sz, 1] = 1

    fig, ax = plt.subplots()
    ax.axis('off')
    ax.imshow(img0)
    ax.imshow(color_mask, alpha=0.5)

    # Rasterise the figure and grab its pixels as an (H, W, 3) uint8 array.
    # np.frombuffer replaces the deprecated binary mode of np.fromstring; the
    # copy keeps the frame writable, as before.
    # NOTE(review): tostring_rgb is itself deprecated in recent Matplotlib
    # releases -- confirm the installed version still provides it.
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    frames.append(data)
    # Close this specific figure so figures do not pile up across frames.
    plt.close(fig)

segment_clip = ImageSequenceClip(frames, fps=fps)
name = str(Path(ROOT)/'mlcv-exp/data/saved'/'model_apt_{}.gif'.format(exp))
segment_clip.write_gif(name, fps=fps)

with open(name, 'rb') as f:
    display(IPythonImage(data=f.read(), format='png'))

# Continuous Learning

In [None]:
# Configuration used by train_model for per-frame fine-tuning.  It has no
# 'chosen_pt' entry here: train_model writes that key before building the dataset.
test_cfg = {
    'img_root': Path(ROOT)/'EPIC_KITCHENS_2018'/'EK_frames',  # directory the frame paths are joined onto
    'img_tmpl': 'img_{:05d}.jpg',  # frame file-name template
    'img_rsz': 480,  # every frame is resized to img_rsz x img_rsz
    'crop_sz': 80,  # side of the square crop; also used as the zero-padding width
    'len_data': 128,  # value returned by Train_Dataset.__len__
    'jitter': 0,  # 0 disables the random offset around the chosen point
    'batch_size': 32,  # DataLoader batch size
    'num_workers': 8,  # DataLoader worker count
    'show_pos_prob': 0.5,  # Train_Dataset samples at the chosen point when a uniform draw is <= this value
}

In [None]:
def train_model(model, optimizer, pred_frame_pts, img_path, max_epoch):
    """Fine-tune ``model`` in place on crops of one frame around a given point.

    Args:
        model: network to train; it is switched to training mode and left there.
        optimizer: optimizer stepped once per batch.
        pred_frame_pts: point stored as ``test_cfg['chosen_pt']`` and used by
            ``Train_Dataset`` as the target location.
        img_path: path of the frame the dataset is built from.
        max_epoch: number of passes over the dataset.

    Returns:
        None.
    """
    # Train_Dataset reads the target point from its config, so it is written
    # into the notebook-level test_cfg before the dataset is created.
    test_cfg['chosen_pt'] = pred_frame_pts
    dataset = Train_Dataset(test_cfg, img_path)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=test_cfg['batch_size'],
        shuffle=True,
        num_workers=test_cfg['num_workers'],
        sampler=None,
        pin_memory=True,
    )

    criterion = torch.nn.MSELoss()
    model.train()

    for _ in range(max_epoch):
        for img, iou in dataloader:
            prediction = model(img.cuda())
            # Targets are cast to float32 before being moved to the GPU.
            target = iou.float().cuda()

            # The model emits one value per sample in column 0.
            loss = criterion(prediction[:, 0], target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [None]:
from moviepy.editor import ImageSequenceClip
from tqdm import tqdm
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from IPython.display import Image as IPythonImage

fps = 12  # frame rate passed to ImageSequenceClip and write_gif

frames = []

# Every frame is resized to img_rsz x img_rsz, so the grid of window corners
# (top-left, row-major) is the same for all frames and is computed once.
pos_list = [(row, col)
            for row in range(0, img_rsz - crop_sz + 1, steps)
            for col in range(0, img_rsz - crop_sz + 1, steps)]

for frame in tqdm(img_list):
    # img_list entries are already joined with img_root; cv2.imread returns
    # None on failure, so check before using the result.
    img0 = cv2.imread(str(frame))
    if img0 is None:
        raise FileNotFoundError('Could not read frame: {}'.format(frame))
    # Resize with the notebook-level img_rsz; `rsz` only exists as a local
    # inside Train_Dataset.__init__ and raised a NameError here.
    img0 = cv2.resize(img0[:, :, ::-1], (img_rsz, img_rsz))

    img_crops = [img0[row:row+crop_sz, col:col+crop_sz] for row, col in pos_list]

    # train_model leaves the model in training mode, so evaluation mode has
    # to be restored on every iteration.
    model = model.eval()
    pred_iou = []

    with torch.no_grad():
        for crop in img_crops:
            sample = transform({'img': crop.copy()})
            img = sample['img'].unsqueeze(0).cuda()  # add the batch axis
            # Plain floats, so np.argmax below never sees CUDA tensors.
            pred_iou.append(model(img)[0, 0].item())

    # Top-left corner of the best window, and the window centre as (row, col),
    # which becomes the target point for one epoch of fine-tuning on this frame.
    pos = pos_list[np.argmax(pred_iou)]
    pos_cen = np.asarray(pos) + crop_sz//2
    train_model(model, optimizer, pos_cen, str(frame), 1)

    # Green overlay covering the predicted window.
    color_mask = np.zeros((img0.shape[0], img0.shape[1], 3))
    color_mask[pos[0]:pos[0]+crop_sz, pos[1]:pos[1]+crop_sz, 1] = 1

    fig, ax = plt.subplots()
    ax.axis('off')
    ax.imshow(img0)
    ax.imshow(color_mask, alpha=0.5)

    # Rasterise the figure and grab its pixels as an (H, W, 3) uint8 array.
    # np.frombuffer replaces the deprecated binary mode of np.fromstring; the
    # copy keeps the frame writable, as before.
    # NOTE(review): tostring_rgb is itself deprecated in recent Matplotlib
    # releases -- confirm the installed version still provides it.
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    frames.append(data)
    # Close this specific figure so figures do not pile up across frames.
    plt.close(fig)

segment_clip = ImageSequenceClip(frames, fps=fps)
name = str(Path(ROOT)/'mlcv-exp/data/saved'/'model_apt_1.gif')
segment_clip.write_gif(name, fps=fps)

with open(name, 'rb') as f:
    display(IPythonImage(data=f.read(), format='png'))

In [None]:
# Save the current model and optimizer state.  `epoch` is whatever value the
# notebook-level name currently holds (train_model's loop variable is local).
final_ckpt_path = Path(ROOT)/'mlcv-exp'/'data'/'weights'/'model_apt_5.state'
state = dict(
    epoch=epoch,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
)
torch.save(state, final_ckpt_path)