# Task 3: Helper notebook for loading the data and saving the predictions

In [1]:
%load_ext autoreload
%autoreload 2
import pickle
import gzip
import numpy as np
import os
import cv2
from copy import deepcopy


from mvseg.mvseg.utils.experiments import load_experiment
import glob
import torch
from mvseg.mvseg.datasets import get_dataset

import argparse
from pathlib import Path
import signal
import shutil
import re
import os
import copy
from collections import defaultdict
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from mvseg.mvseg.datasets import get_dataset
from mvseg.settings import TRAINING_PATH
from mvseg import logger
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from skimage.transform import resize


### Helper functions

In [2]:
def load_zipped_pickle(filename):
    with gzip.open(filename, 'rb') as f:
        loaded_object = pickle.load(f)
        return loaded_object

In [3]:
def save_zipped_pickle(obj, filename):
    with gzip.open(filename, 'wb') as f:
        pickle.dump(obj, f, 2)

### Load data, make predictions and save prediction in correct format

In [4]:
# load data
train_data = load_zipped_pickle("/cluster/project/infk/cvg/students/alpaul/MitralValveSegmentation/data/train.pkl")
test_data = load_zipped_pickle("/cluster/project/infk/cvg/students/alpaul/MitralValveSegmentation/data/test.pkl")
# samples = load_zipped_pickle("sample.pkl")

## Get basic dataset info

In [25]:
print(f"Train data info:")
print(f"    Number of train videos: {np.array(train_data).shape[0]}")
print(f"    Min num frames in train videos: {min([data['video'].shape[2] for data in train_data])}")
print(f"    Max num frames in train videos: {max([data['video'].shape[2] for data in train_data])}")
print(f"    Average num of frames in train videos: {np.mean([data['video'].shape[2] for data in train_data])}")
# print([np.array(train_data[i]['box']).shape for i in range(len(train_data))])

print(f"Test data info:")
print(f"    Number of test videos: {np.array(test_data).shape[0]}")
print(f"    Min num frames in test videos: {min([data['video'].shape[2] for data in test_data])}")
print(f"    Max num frames in test videos: {max([data['video'].shape[2] for data in test_data])}")
print(f"    Average num of frames in test videos: {np.mean([data['video'].shape[2] for data in test_data])}")
print(f"Shapes: {[data['video'].shape for data in test_data]}")

Train data info:
    Number of train videos: 65
    Min num frames in train videos: 54
    Max num frames in train videos: 334
    Average num of frames in train videos: 151.83076923076922
Test data info:
    Number of test videos: 20
    Min num frames in test videos: 39
    Max num frames in test videos: 125
    Average num of frames in test videos: 75.35
Shapes: [(586, 821, 103), (587, 791, 52), (583, 777, 69), (582, 851, 61), (732, 845, 53), (583, 809, 84), (582, 737, 78), (587, 775, 125), (730, 956, 76), (587, 781, 104), (583, 681, 68), (587, 713, 90), (587, 612, 78), (587, 773, 73), (707, 855, 39), (731, 1007, 72), (583, 780, 106), (583, 670, 63), (594, 745, 51), (583, 779, 62)]


## Create videos for visualization

In [None]:
num_videos_to_save = 65

train = deepcopy(train_data)
for i in range(num_videos_to_save):
    video = train[i]['video']
    box = np.array(train[i]['box'])
    label = train[i]['label']
    label = label.astype('float32')
    print(label)
    break
#     label = label*255
#     print(list(label[train[i]['frames'][0]]))
    label = np.moveaxis(label, -1, 0)
    video = np.moveaxis(video, -1, 0)
    first_image = video[0]
    size = first_image.shape
    fps = 10
    out = cv2.VideoWriter(f'data/train_videos/{i}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (size[1], size[0]), False)
    for j in range(len(video)):
        image = video[j]
#         image[box] = 255
        labelj = label[j]
        image[labelj] = 255
        box = box.astype('uint8')*255
        label = label.astype('uint8')*255
        out.write(image)
    out.release()

In [48]:
import torch
## Create inference videos

# Load model

experiments = ['exp1_finetune_1_head'] # ['exp1_finetune_2_heads'] # ['exp1_box_gpu']
all_experiments = {}
experiment = experiments[0]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

conf = '/cluster/home/alpaul/MitralValveSegmentation/mvseg/mvseg/configs/config_train.yaml'

logger.info(f'Starting test {experiment}')
output_dir = Path(TRAINING_PATH, experiment)
#     conf = OmegaConf.load(conf)
conf = OmegaConf.merge(OmegaConf.load(conf), {'train':{'num_workers': 0}})
data_conf = conf.data
dataset = get_dataset(data_conf.name)(data_conf)
test_loader = dataset.get_data_loader('test')
model = load_experiment(experiment, conf.model)
model = model.eval()

loss_fn, metrics_fn = model.loss, model.metrics
viz = []
all_experiments[experiment] = []


#         all_experiments[experiment].append({
#                         'L1_roll': error_roll.cpu().item(),
#                         'L1_pitch': error_pitch,
#                         'L1_fov': metrics['fov/L1_degree_loss'].cpu().item(),
#                         'name': name,
#                         **{'gt_'+str(gt_key):gts[gt_key].unsqueeze(0).cpu().item()  
#                            if isinstance(gts[gt_key],torch.Tensor)
#                            else gts[gt_key]
#                            for gt_key in gts},
#                         **{'pred_'+str(pred_key):preds[pred_key].unsqueeze(0).cpu().item()  
#                            if isinstance(preds[pred_key],torch.Tensor)
#                            else preds[pred_key]
#                            for pred_key in preds}
#                         })


# Read all frames into list
# resize if needed
# feed through and get box output
# if resized, resize back.
# mask the image frame with the boxoutput and save to video



[12/17/2022 14:34:44 mvseg INFO] Starting test exp1_finetune_1_head
[12/17/2022 14:34:44 mvseg.mvseg.datasets.base_dataset INFO] Creating dataset MVSegDataset
[12/17/2022 14:35:09 mvseg.mvseg.utils.experiments INFO] Loading checkpoint checkpoint_best.tar


In [88]:
# Box only
from copy import deepcopy

# count = 0
fps = 10
# out = cv2.VideoWriter(f'full_test_resize_WORKING.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (821, 586), False) # compatible 821, 586


all_preds = {}
all_images = {}

count = 0
for data in tqdm(test_loader, desc='Testing', ascii=True, disable=False):
    count += 1
    
    with torch.no_grad():   
        pred = model(data)
        box_pred = pred['box_seg'].squeeze(0).squeeze(0)
        h,w = data['hw']
#         if not (h[0] == 586 and w[0] == 821):
#             print("Skipping")
#             continue
        box_pred = torch.tensor(resize(box_pred, (h[0],w[0]), anti_aliasing=True))
        pred_box_mask = torch.gt(box_pred, 0.8)
        pred_box_mask = pred_box_mask.numpy().astype('bool')
#         im = data['image'].squeeze(0).squeeze(0)
#         im = resize(im, (h[0],w[0]), anti_aliasing=True)
#         print(im.shape, pred_box_mask.shape)
#         print(type(im), type(pred_box_mask))
#         assert im.shape == pred_box_mask.shape
#         im = im * 255 * 255
        
#         im[pred_box_mask] = 255
        
#         im = im.astype('uint8')
#         print(np.max(im), np.max(pred_box_mask))
#         print(im)
#         plt.imshow(im)


        if data['video'][0] not in all_preds.keys():
            all_preds[data['video'][0]] = [pred_box_mask]
#             all_images[data['video'][0]] = [im]
        else:
            all_preds[data['video'][0]].append(pred_box_mask)
#             all_images[data['video'][0]].append(im)
#         out.write(im)
#         if count >= 100:
#             break
            
            
#         print(type(im))
#         out.write(pred_box_mask.numpy().astype('uint8'))
        
#         plt.imshow(pred_box_mask.numpy().astype('uint8'))

# out.release()



  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)


  return torch.stack(batch, 0, out=out)
Testing: 100%|###########################################################################| 1507/1507 [01:50<00:00, 13.60it/s]


In [55]:
# Arrange it for prediction

print(all_preds)

{'E9AHVWGBUF': [tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]]), tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]]), tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [Fals

In [65]:
# Sem only
# count = 0
# fps = 10
# out = cv2.VideoWriter(f'test_1.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (112, 112), False)
# for data in tqdm(test_loader, desc='Testing', ascii=True, disable=False):
#     count += 1
#     with torch.no_grad():
# #         print(data)        
#         pred = model(data)
#         sem_pred = pred['sem_seg'].squeeze(0).squeeze(0)
# #         print(torch.max(box_pred, dim=0))
#         pred_sem_mask = torch.gt(sem_pred, 0.5)
#         im = data['image'].squeeze(0).squeeze(0)
        
#         im = im.numpy() * 255 * 255
# #         print(np.max(im, axis=0))
# #         im = im.astype('uint8')
#         im[pred_sem_mask] = 255.
#         im = im.astype('uint8')
# #         print(type(im))
#         out.write(im)
# #         plt.imshow(im)
#         if count>=75:
#             out.release()
#             break


  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)
  return torch.stack(batch, 0, out=out)


  return torch.stack(batch, 0, out=out)
Testing:   5%|####2                                                                                 | 74/1507 [00:04<01:35, 14.99it/s]


In [89]:
# make prediction for test
predictions = []
for d in test_data:
#     print(d['video'].shape) # this is what we need to resize to.
#     prediction = np.array(np.zeros_like(d['video']), dtype=np.bool)


    prediction = np.array([list(p) for p in all_preds[d['name']]])
#     images = np.array([list(p) for p in all_images[d['name']]])
#     print(images.shape, prediction.shape)
    print(prediction.shape)
    height = prediction.shape[1]
    width = prediction.shape[2]
    out = cv2.VideoWriter(f'{d["name"]}_0.75.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height), False)
    for i in range(len(prediction)):
        mask = (np.array(prediction[i])).astype('bool') #*255.).astype('uint8')
#         print(mask)
#         print(im.shape, mask.shape)
#         im = (np.array(images[i])).astype('uint8')
#         im[mask] = 255.
#         im = np.array(images[i])
        out.write(mask.astype('uint8') * 255)
    prediction = np.moveaxis(prediction, 0, -1)
#     prediction[int(height/2)-50:int(height/2+50), int(width/2)-50:int(width/2+50)] = True
    # DATA Structure
    predictions.append({
        'name': d['name'],
        'prediction': prediction
        }
    )
    out.release()


(103, 586, 821)
(52, 587, 791)
(69, 583, 777)
(61, 582, 851)
(53, 732, 845)
(84, 583, 809)
(78, 582, 737)
(125, 587, 775)
(76, 730, 956)
(104, 587, 781)
(68, 583, 681)
(90, 587, 713)
(78, 587, 612)
(73, 587, 773)
(39, 707, 855)
(72, 731, 1007)
(106, 583, 780)
(63, 583, 670)
(51, 594, 745)
(62, 583, 779)


In [90]:
# save in correct format
save_zipped_pickle(predictions, 'my_predictions0.75.pkl')

In [95]:
# Debugging

import numpy as np


test_data = load_zipped_pickle( '/cluster/project/infk/cvg/students/alpaul/MitralValveSegmentation/data/test.pkl')
b = False
items = []
for data in test_data:
    video = np.moveaxis(data['video'], -1, 0)
    for i, im in enumerate(video):
        print(im.shape)
        continue
        if not (im.shape[0] == im.shape[1] == 112):
            im = resize(im, (112, 112), anti_aliasing=True)
            # here check what happens
            b = True
        
        im = torch.from_numpy(im / 255.).float().unsqueeze(0)
        
        # here check what happens
        print(type(im))
        items.append(im)

(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)
(586, 821)