# Task 3: Helper notebook for loading the data and saving the predictions

In [1]:
%load_ext autoreload
%autoreload 2
import pickle
import gzip
import numpy as np
import os
import cv2
from copy import deepcopy
from mvseg.mvseg.utils.experiments import load_experiment
import glob
import torch
from mvseg.mvseg.datasets import get_dataset
import argparse
from pathlib import Path
import signal
import shutil
import re
import os
import copy
from collections import defaultdict
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch
from mvseg.mvseg.datasets import get_dataset
from mvseg.settings import TRAINING_PATH
from mvseg import logger
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from skimage.transform import resize


### Helper functions

In [2]:
def load_zipped_pickle(filename):
    """Read a gzip-compressed pickle file and return the object stored in it.

    NOTE: unpickling can execute arbitrary code embedded in the file, so only
    call this on files that come from a trusted source.
    """
    with gzip.open(filename, mode='rb') as fh:
        return pickle.load(fh)
def save_zipped_pickle(obj, filename):
    """Pickle `obj` with protocol 2 and write it gzip-compressed to `filename`."""
    pickle_protocol = 2
    with gzip.open(filename, mode='wb') as fh:
        pickle.dump(obj, fh, protocol=pickle_protocol)

### Load data, make predictions and save predictions in the correct format

In [3]:
# Read the gzip-pickled train and test splits from the cluster project directory.
train_pickle = "/cluster/project/infk/cvg/students/alpaul/MitralValveSegmentation/data/train.pkl"
test_pickle = "/cluster/project/infk/cvg/students/alpaul/MitralValveSegmentation/data/test.pkl"

train_data = load_zipped_pickle(train_pickle)
test_data = load_zipped_pickle(test_pickle)
# samples = load_zipped_pickle("sample.pkl")

## Get basic dataset info

In [5]:
def print_split_info(split_name, split_data):
    """Print the video count and min / max / mean frame count of one data split.

    Args:
        split_name: lower-case split label used in the messages, e.g. 'train'.
        split_data: sequence of samples; each sample's 'video' entry is an
            array whose third axis (``shape[2]``) is reported as the number
            of frames.
    """
    # Collect the frame counts once instead of rebuilding the list per statistic.
    frame_counts = [sample['video'].shape[2] for sample in split_data]
    print(f"{split_name.capitalize()} data info:")
    print(f"    Number of {split_name} videos: {len(split_data)}")
    print(f"    Min num frames in {split_name} videos: {min(frame_counts)}")
    print(f"    Max num frames in {split_name} videos: {max(frame_counts)}")
    print(f"    Average num of frames in {split_name} videos: {np.mean(frame_counts)}")


print_split_info('train', train_data)
# print([np.array(train_data[i]['box']).shape for i in range(len(train_data))])

print_split_info('test', test_data)
# Test videos differ in spatial size, so list every shape explicitly.
print(f"Shapes: {[data['video'].shape for data in test_data]}")

Train data info:
    Number of train videos: 65
    Min num frames in train videos: 54
    Max num frames in train videos: 334
    Average num of frames in train videos: 151.83076923076922
Test data info:
    Number of test videos: 20
    Min num frames in test videos: 39
    Max num frames in test videos: 125
    Average num of frames in test videos: 75.35
Shapes: [(586, 821, 103), (587, 791, 52), (583, 777, 69), (582, 851, 61), (732, 845, 53), (583, 809, 84), (582, 737, 78), (587, 775, 125), (730, 956, 76), (587, 781, 104), (583, 681, 68), (587, 713, 90), (587, 612, 78), (587, 773, 73), (707, 855, 39), (731, 1007, 72), (583, 780, 106), (583, 670, 63), (594, 745, 51), (583, 779, 62)]


## Create videos for visualization

In [None]:
# Render each training video with its label mask painted white and save it as
# a grayscale mp4, one file per video.
num_videos_to_save = 65
fps = 10
train_video_dir = Path('/cluster/project/infk/cvg/students/alpaul/MitralValveSegmentation/data/train_videos')
# cv2.VideoWriter does not create missing directories, so make sure it exists.
train_video_dir.mkdir(parents=True, exist_ok=True)

# Never index past the end of the dataset if it holds fewer videos than requested.
for i in range(min(num_videos_to_save, len(train_data))):
    sample = train_data[i]
    video = sample['video']
    # Boolean mask so that `frame[mask] = 255` selects pixels; a float mask is
    # rejected as an index and an integer mask would select whole rows instead.
    # Assumes 'label' has the same (height, width, frames) layout as 'video'
    # -- TODO confirm against the dataset description.
    label = np.asarray(sample['label']).astype(bool)
    height, width = video.shape[:2]
    num_frames = video.shape[-1]
    out = cv2.VideoWriter(str(train_video_dir / f'{i}.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height), False)
    try:
        for j in range(num_frames):
            # Copy the frame so the overlay never modifies train_data itself;
            # this replaces the former deepcopy of the whole dataset.
            image = video[..., j].copy()
            image[label[..., j]] = 255
            out.write(image)
    finally:
        # Finalise the file even if writing a frame fails.
        out.release()

# Generate videos or submission array

## Load data

In [5]:
# Name of the trained experiment whose checkpoint is evaluated below.
experiment = 'exp2_stage2_ft'
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
conf = '/cluster/home/alpaul/MitralValveSegmentation/mvseg/mvseg/configs/config_train.yaml'
logger.info(f'Starting test {experiment}')
output_dir = Path(TRAINING_PATH, experiment)

# Load the YAML config and override train.num_workers with 0.
base_conf = OmegaConf.load(conf)
conf = OmegaConf.merge(base_conf, {'train':{'num_workers': 0}})
data_conf = conf.data

# Build the dataset named in the config and take its 'test' data loader.
dataset_cls = get_dataset(data_conf.name)
dataset = dataset_cls(data_conf)
test_loader = dataset.get_data_loader('test')

# Restore the experiment's model and put it into evaluation mode.
model = load_experiment(experiment, conf.model).eval()
loss_fn = model.loss
metrics_fn = model.metrics

[12/18/2022 12:22:58 mvseg INFO] Starting test exp2_stage2_ft
[12/18/2022 12:22:58 mvseg.mvseg.datasets.base_dataset INFO] Creating dataset MVSegDataset
  cpuset_checked))
[12/18/2022 12:23:18 mvseg.mvseg.utils.experiments INFO] Loading checkpoint checkpoint_best.tar


## Loop through the dataloader and write a video. Optionally, store all predicted masks per video for submission

In [None]:
# Run the model on every test frame. Depending on `submit`, either write all
# frames with the predicted mask painted white into a single video, or collect
# the resized boolean masks per video for the submission file.
count = 0
threshold = 0.3  # a pixel is foreground when its predicted value exceeds this
all_preds = {}  # video name -> list of per-frame boolean masks (submit mode only)
all_images = {}
fps = 10

submit = False # Set false if you want to generate video instead

out = None
if not submit:
    out = cv2.VideoWriter(f'{experiment}_test_full_t{threshold}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224), False) # compatible 821, 586

try:
    for data in tqdm(test_loader, desc='Testing', ascii=True, disable=False):
        count += 1
        with torch.no_grad():
            pred = model(data)
            # Drop the two leading singleton axes (assumes batch size 1) and move
            # the result to host memory: NumPy / skimage cannot read CUDA tensors.
            seg = pred['seg'].squeeze(0).squeeze(0).cpu().numpy()
        if submit:
            # Resize to the size stored in data['hw'] before thresholding, so the
            # mask can be compared against the original video shape later on.
            h, w = data['hw']
            seg = resize(seg, (int(h[0]), int(w[0])), anti_aliasing=True)
            mask_pred = seg > threshold
            all_preds.setdefault(data['video'][0], []).append(mask_pred)
        else:
            mask_pred = seg > threshold
            im = data['image'].squeeze(0).squeeze(0).cpu().numpy()
#             im = resize(im, (h[0],w[0]), anti_aliasing=True)
            assert im.shape == mask_pred.shape
            # Scale to the 0-255 range (presumably the image is in [0, 1] --
            # TODO confirm), then paint the predicted mask white.
            im = im * 255
            im[mask_pred] = 255
            out.write(im.astype('uint8'))
finally:
    # Always finalise the video file, even if inference fails midway.
    if out is not None:
        out.release()

## Submit

In [13]:
# Assemble the submission: one {'name', 'prediction'} entry per test video,
# with 'prediction' shaped exactly like that video's frame array.
if not all_preds:
    raise RuntimeError(
        "No predictions collected: re-run the inference cell with submit = True."
    )

predictions = []
for d in test_data:
    # Stack the per-frame masks along a new leading axis, then move that frame
    # axis to the end so the layout matches d['video'].
    prediction = np.moveaxis(np.stack(all_preds[d['name']]), 0, -1)
    assert prediction.shape == d['video'].shape, (
        f"{d['name']}: prediction shape {prediction.shape} != video shape {d['video'].shape}"
    )
    predictions.append({
        'name': d['name'],
        'prediction': prediction
        }
    )

In [None]:
# Write the submission list as a gzip-compressed pickle; the file name records
# the threshold that produced the masks.
submission_file = f'my_predictions_{threshold}_retrain.pkl'
save_zipped_pickle(predictions, submission_file)

## Ignore this block: Loop through test dataset, and generate 1 video per test video

In [None]:
# output_directory = '/cluster/home/alpaul/videos'
# threshold = 0.7
# fps = 10
# predictions = []

# for d in tqdm(test_data):

#     d_video_frames = np.moveaxis(d['video'], -1, 0)
#     print(d_video_frames.shape)
#     height, width = d_video_frames.shape[1], d_video_frames.shape[2]
#     out = cv2.VideoWriter(f'{output_directory}/{d["name"]}_{threshold}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height), False)
#     prediction = []
#     count = 0
#     with torch.no_grad():
#         for idx in tqdm(range(len(d_video_frames))):
#             count += 1
#             d_im = list(filter(lambda p : p['frame_number'] == idx and p['video'][0] == d['name'], test_loader))[0]
#             pred = model(d_im)
#             mask_pred = pred['seg'].squeeze(0).squeeze(0)
#             h,w = d_im['hw'] # height to resize to
#             mask_pred = torch.tensor(resize(mask_pred, (h[0],w[0]), anti_aliasing=True))
#             mask_pred = torch.gt(mask_pred, threshold)
#             mask_pred = mask_pred.numpy().astype('bool')
#             im = d_im['image'].squeeze(0).squeeze(0)
#             im = resize(im, (h[0],w[0]), anti_aliasing=True)
#             assert im.shape == mask_pred.shape
#             im = im * 255
#             im[mask_pred] = 255
#             im = im.astype('uint8')
#             prediction.append(mask_pred)
#             out.write(im.astype('uint8'))
# #             plt.imshow(mask_pred.astype('float32')*255)
# #             print(np.max(mask_pred.astype('uint8')*255))
#     out.release()
#     del out
#     print(f"shape of prediction list: {np.array(prediction).shape}")
#     prediction = np.moveaxis(prediction, 0, -1)
#     print(f"shape of prediction list: {prediction.shape}")
#     predictions.append({
#         'name': d['name'],
#         'prediction': prediction
#         }
#     )
    