In [3]:
# TODO: split each file into its function definitions and a main.py part
# from segmentation_mask import *
# from frame_prediction import *

import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


In [4]:
# Auxiliary functions:

def get_localisation_mask(original_mask):
    """Collapse a multi-class mask into a binary object-vs-background mask.

    Label 0 is treated as background and maps to 0; every other label
    (the original note mentions labels 0..48) maps to 1. The result has the
    same shape as ``original_mask``.
    """
    is_object = original_mask != 0
    # Multiplying the boolean mask by 1 converts it to integer 0/1 values.
    return is_object * 1

def get_color_ratios(original_image, eps=1e-6):
    """Append R/G, G/B and B/R ratio channels to an RGB image tensor.

    Args:
        original_image: Tensor indexed as (channel, height, width) whose first
            three channels are R, G and B.
        eps: Value substituted for denominators that are exactly zero, so
            black pixels do not produce ``inf``/``nan``. Non-zero denominators
            are left untouched.

    Returns:
        Tensor with 6 channels stacked on dim 0: R, G, B, R/G, G/B, B/R.
    """
    red, green, blue = original_image[0], original_image[1], original_image[2]

    def _safe_ratio(numerator, denominator):
        # Integer tensors cannot hold eps, so promote them to float first.
        if not torch.is_floating_point(denominator):
            denominator = denominator.to(torch.get_default_dtype())
        # Only exact zeros are replaced; every other value divides as before.
        zero_free = torch.where(
            denominator == 0, torch.full_like(denominator, eps), denominator
        )
        return numerator / zero_free

    all_channels = [
        red,
        green,
        blue,
        _safe_ratio(red, green),
        _safe_ratio(green, blue),
        _safe_ratio(blue, red),
    ]
    return torch.stack(all_channels, dim=0)



In [6]:
# Data Loader Definitions
#
# The triple-quoted string below is a bare expression statement, so the class
# text inside it is never executed. It holds a per-frame CustomDataset variant:
# each item is one (frame, mask) pair looked up by a (video, frame) index and
# read from a scratch directory keyed by a global `net_id`.

"""
class CustomDataset(Dataset):
    def __init__(self, all_frames, evaluation_mode = False):
        self.frames = torch.tensor(all_frames)
        self.evaluation_mode = evaluation_mode
#         self.masks = all_masks.cuda()

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        global net_id
        i,j = self.frames[idx]
        mode = 'val' if self.evaluation_mode else 'train'
        file_path = f"./../../../scratch/{net_id}/dataset_videos/dataset/{mode}/video_{i}/image_{j}.png"
        frame = torch.tensor(plt.imread(file_path)).permute(2, 0, 1)

        file_path = f"./../../../scratch/{net_id}/dataset_videos/dataset/{mode}/video_{i}/mask.npy"
        mask = np.load(file_path)[j]
        return frame, mask
"""

In [None]:
# Data Loader.
class CustomDataset(Dataset):
    """Map-style dataset yielding (input frames, last-frame mask) per video.

    Items are read from ``{data_root}/{mode}/video_{i}/``, where ``mode`` is
    ``'val'`` when ``evaluation_mode`` is true and ``'train'`` otherwise.

    Args:
        num_of_vids: Number of consecutive video indices exposed.
        evaluation_mode: Read from the ``val`` split instead of ``train``.
        start_num: Index of the first video (defaults to 0).
        data_root: Directory containing the split folders.

    NOTE(review): whether the ``val`` split's video folders are numbered from
    0 is not visible here; pass ``start_num`` if they use another offset.
    """

    # Frames image_0 .. image_10 form the model input.
    NUM_INPUT_FRAMES = 11
    # mask.npy is indexed per frame; the target is the last one (index 21).
    NUM_TOTAL_FRAMES = 22

    def __init__(self, num_of_vids=1000, evaluation_mode=False, start_num=0,
                 data_root='../dataset'):
        self.vid_indexes = torch.arange(start_num, start_num + num_of_vids)
        self.evaluation_mode = evaluation_mode
        self.data_root = data_root

    def __len__(self):
        """Return the number of videos; DataLoader samplers require this."""
        return len(self.vid_indexes)

    def _video_dir(self, idx):
        """Return the folder (with trailing slash) of the idx-th video."""
        # int() turns the 0-dim tensor into a plain number for the path.
        video_id = int(self.vid_indexes[idx])
        mode = 'val' if self.evaluation_mode else 'train'
        return f'{self.data_root}/{mode}/video_{video_id}/'

    def __getitem__(self, idx):
        """Load the input frames and the final-frame mask of one video.

        Returns:
            x: Stack of the first ``NUM_INPUT_FRAMES`` images, each permuted
               with ``(2, 0, 1)`` — presumably (frames, channels, height,
               width); confirm against the image files.
            y: Entry ``NUM_TOTAL_FRAMES - 1`` of the video's ``mask.npy``.
        """
        video_dir = self._video_dir(idx)
        frames = [
            torch.tensor(plt.imread(f'{video_dir}image_{j}.png')).permute(2, 0, 1)
            for j in range(self.NUM_INPUT_FRAMES)
        ]
        x = torch.stack(frames, 0)
        y = np.load(f'{video_dir}mask.npy')[self.NUM_TOTAL_FRAMES - 1]
        return x, y

In [2]:
# Dataloader
num_videos = 1000
num_frames_per_video = 22

# Flat (video index, frame index) table: one row per frame of every video,
# ordered by video and then by frame.
all_frames = torch.tensor(
    [[i, j] for i in range(num_videos) for j in range(num_frames_per_video)]
)

batch_size = 8
num_val_videos = 1
# num_val_videos = 1000

# The dataset class in this notebook is `CustomDataset`; the previous name
# `CreateDatasetCustom` is not defined in any visible cell.
# For a quick smoke test, lower `num_videos` before building the loaders.
train_data = CustomDataset(num_videos)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

val_data = CustomDataset(num_val_videos, evaluation_mode=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

In [3]:
# Instantiate the frame-prediction model.
# NOTE(review): `DLModelVideoPrediction` and `device` are not defined in the
# visible cells; both must already exist in the kernel namespace.
# The first argument is presumably (frames, channels, height, width) — confirm
# against the model definition.
model = DLModelVideoPrediction((11, 3, 160, 240), 64, 512, groups=4)
model = nn.DataParallel(model)
model = model.to(device)

# Resume from the previous best checkpoint when one exists.
best_model_path = './checkpoint_frame_prediction.pth'
if os.path.isfile(best_model_path):
    # map_location remaps the saved tensors onto `device`, so a checkpoint
    # written on a GPU can still be restored on a machine without one.
    # NOTE(review): the state dict is loaded into the DataParallel wrapper —
    # assumes the checkpoint was saved from a wrapped model; confirm.
    model.load_state_dict(torch.load(best_model_path, map_location=device))


In [4]:
# Load trained model weights

In [None]:
# Training Loop

# FLOW:
# get 11 frames of video from dataloader (optional: Data Augmentation)
# pass it through model to get prediction for 22nd frame
# pass prediction through segmentation model