In [1]:
# TODO: separate this notebook into a module of function definitions and a main.py part
from models import *

import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


In [2]:
# Auxiliary functions:

def get_localisation_mask(original_mask):
    """Collapse a multi-class mask into a binary object-vs-background mask.

    Label 0 (background) stays 0; every other label (1..48 in this dataset,
    per the original note) becomes 1.
    """
    is_object = original_mask != 0
    # Multiplying the boolean mask by 1 turns True/False into 1/0.
    binary_mask = is_object * 1
    return binary_mask

def get_color_ratios(original_image, eps=1e-6):
    """Append R/G, G/B and B/R ratio channels to an RGB image.

    Args:
        original_image: tensor indexed as (channel, height, width); channels
            0, 1, 2 are read as R, G, B.
        eps: lower bound applied to every denominator before dividing, so a
            zero channel value (e.g. a black pixel) cannot produce inf/NaN.
            Denominators already >= eps are used unchanged.
            NOTE: assumes non-negative pixel intensities.

    Returns:
        Tensor with 6 channels stacked on dim 0: R, G, B, R/G, G/B, B/R.
    """
    red, green, blue = original_image[0], original_image[1], original_image[2]

    def _safe_ratio(numerator, denominator):
        # Flooring the denominator keeps the ratio finite where it is zero.
        return numerator / denominator.clamp(min=eps)

    all_channels = [
        red,
        green,
        blue,
        _safe_ratio(red, green),
        _safe_ratio(green, blue),
        _safe_ratio(blue, red),
    ]
    return torch.stack(all_channels, dim=0)



In [3]:
# Data Loader Definitions
# NOTE: the class below is disabled. It is wrapped in a bare string literal, so it is
# never defined; evaluating the cell only echoes the string as the cell output.
# The CustomDataset actually used by this notebook is defined in a later cell.

"""
class CustomDataset(Dataset):
    def __init__(self, all_frames, evaluation_mode = False):
        self.frames = torch.tensor(all_frames)
        self.evaluation_mode = evaluation_mode
#         self.masks = all_masks.cuda()

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        global net_id
        i,j = self.frames[idx]
        mode = 'val' if self.evaluation_mode else 'train'
        file_path = f"./../../../scratch/{net_id}/dataset_videos/dataset/{mode}/video_{i}/image_{j}.png"
        frame = torch.tensor(plt.imread(file_path)).permute(2, 0, 1)

        file_path = f"./../../../scratch/{net_id}/dataset_videos/dataset/{mode}/video_{i}/mask.npy"
        mask = np.load(file_path)[j]
        return frame, mask
"""

'\nclass CustomDataset(Dataset):\n    def __init__(self, all_frames, evaluation_mode = False):\n        self.frames = torch.tensor(all_frames)\n        self.evaluation_mode = evaluation_mode\n#         self.masks = all_masks.cuda()\n\n    def __len__(self):\n        return len(self.frames)\n\n    def __getitem__(self, idx):\n        global net_id\n        i,j = self.frames[idx]\n        mode = \'val\' if self.evaluation_mode else \'train\'\n        file_path = f"./../../../scratch/{net_id}/dataset_videos/dataset/{mode}/video_{i}/image_{j}.png"\n        frame = torch.tensor(plt.imread(file_path)).permute(2, 0, 1)\n\n        file_path = f"./../../../scratch/{net_id}/dataset_videos/dataset/{mode}/video_{i}/mask.npy"\n        mask = np.load(file_path)[j]\n        return frame, mask\n'

In [4]:
# List the entries of the dataset root directory (the recorded output shows
# the 'train', 'val', 'unlabeled' and 'hidden' folders).
os.listdir('./../Dataset_Student/')

['.DS_Store', 'hidden', 'train', 'unlabeled', 'val']

In [5]:
# Data Loader.
class CustomDataset(Dataset):
    """Per-video dataset: the first 11 frames as input, the last-frame mask as target.

    In training mode each item is ``(x, y)``; in evaluation mode each item is
    ``x`` only, because no mask is loaded for that split.

    Args:
        num_of_vids: number of consecutive videos to expose.
        evaluation_mode: when True read the 'hidden' split (video ids start at
            15000); otherwise read the 'train' split (video ids start at 0).
        base_dir: dataset root containing the split folders.
    """

    NUM_INPUT_FRAMES = 11  # image_0.png .. image_10.png are read for every video
    NUM_TOTAL_FRAMES = 22  # the target mask is taken at the last of these frames

    def __init__(self, num_of_vids=1000, evaluation_mode=False, base_dir='./../Dataset_Student/'):
        self.evaluation_mode = evaluation_mode
        if self.evaluation_mode:
            self.mode = 'hidden'
            start_num = 15000
        else:
            self.mode = 'train'
            start_num = 0
        self.base_dir = base_dir
        self.vid_indexes = torch.arange(start_num, start_num + num_of_vids)
        self.num_of_vids = num_of_vids

    def __len__(self):
        """Return the number of videos exposed by this dataset."""
        return self.num_of_vids

    def __getitem__(self, idx):
        """Load the input frames (and, in training mode, the target mask) of one video.

        Returns:
            ``x`` in evaluation mode, else ``(x, y)``. ``x`` stacks the input
            frames on dim 0, each frame permuted from imread's
            (height, width, channel) layout to (channel, height, width).
            ``y`` is entry ``NUM_TOTAL_FRAMES - 1`` of the video's ``mask.npy``.
        """
        video_id = int(self.vid_indexes[idx])
        video_dir = os.path.join(self.base_dir, self.mode, f'video_{video_id}')

        frames = [
            torch.tensor(plt.imread(os.path.join(video_dir, f'image_{j}.png'))).permute(2, 0, 1)
            for j in range(self.NUM_INPUT_FRAMES)
        ]
        x = torch.stack(frames, 0)

        if self.evaluation_mode:
            return x

        masks = np.load(os.path.join(video_dir, 'mask.npy'))
        y = masks[self.NUM_TOTAL_FRAMES - 1]  # last frame.
        return x, y

In [6]:
# Dataloader
batch_size = 1

# Create Train DataLoader
num_videos = 2
train_data = CustomDataset(num_videos)
# Shuffle training samples every epoch.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Create Val DataLoader
num_val_videos = 1
val_data = CustomDataset(num_val_videos, evaluation_mode = True)
# Do NOT shuffle the evaluation set: its predictions are collected in iteration
# order and saved as a submission, so the order must follow the video ids.
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# batch_size = 8
# num_videos = 1000
# # num_val_videos = 1000

# train_data = CreateDatasetCustom(num_videos)
# train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# val_data = CreateDatasetCustom(num_val_videos, evaluation_mode=True)
# val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

In [7]:
# Sanity check: iterate over both loaders and print the shape of every batch.
# Training batches unpack into (frames, mask); evaluation batches carry frames only.
for x,y in train_loader:
    print(x.shape, y.shape)
    
for x in val_loader:
    print(x.shape)

torch.Size([1, 11, 3, 160, 240]) torch.Size([1, 160, 240])
torch.Size([1, 11, 3, 160, 240]) torch.Size([1, 160, 240])
torch.Size([1, 11, 3, 160, 240])


In [8]:
# CUDA variant of the device selection, kept for reference:
# gpu_name = 'cuda'
# device = torch.device(gpu_name if torch.cuda.is_available() else 'cpu')

# Select the 'mps' backend when it reports itself available, otherwise fall back to CPU.
gpu_name = 'mps'
device = torch.device(gpu_name if torch.backends.mps.is_available() else 'cpu')

# NOTE(review): this line unconditionally overrides the selection above, so everything
# below runs on CPU (the recorded output is "cpu"). Remove it to use the accelerator.
device = torch.device('cpu')
print(device)

cpu


In [9]:
class combined_model(nn.Module):
    """Frame-prediction network chained into a segmentation network.

    ``forward`` runs the frame-prediction model, keeps only the last frame of
    its output along dim 1 and feeds that frame to the segmentation model.

    Args:
        device: device both sub-models are moved to; also used as
            ``map_location`` when checkpoints are loaded.
    """

    # Each sub-model has its own checkpoint file.
    FRAME_PREDICTION_CKPT = './checkpoints/frame_prediction.pth'
    IMAGE_SEGMENTATION_CKPT = './checkpoints/image_segmentation.pth'

    def __init__(self, device):
        super().__init__()
        self.device = device

        self.frame_prediction_model = DLModelVideoPrediction((11, 3, 160, 240), 64, 512, groups=4)
        self.frame_prediction_model = nn.DataParallel(self.frame_prediction_model)
        self.frame_prediction_model = self.frame_prediction_model.to(device)

        self.image_segmentation_model = unet_model()
        self.image_segmentation_model = nn.DataParallel(self.image_segmentation_model)
        self.image_segmentation_model = self.image_segmentation_model.to(device)

    def load_weights(self):
        """Restore each sub-model from its checkpoint file when that file exists.

        A missing file is skipped silently, leaving that sub-model's current weights.
        """
        checkpoints = (
            ('frame prediction', self.frame_prediction_model, self.FRAME_PREDICTION_CKPT),
            ('image segmentation', self.image_segmentation_model, self.IMAGE_SEGMENTATION_CKPT),
        )
        for label, module, path in checkpoints:
            if os.path.isfile(path):
                print(f'{label} model weights found')
                # map_location lets a checkpoint saved on another device load here.
                module.load_state_dict(torch.load(path, map_location=self.device))

    def save_weights(self):
        """Write each sub-model's state dict to its own checkpoint file."""
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(self.FRAME_PREDICTION_CKPT), exist_ok=True)
        torch.save(self.frame_prediction_model.state_dict(), self.FRAME_PREDICTION_CKPT)
        # The segmentation checkpoint must hold the segmentation model's weights.
        torch.save(self.image_segmentation_model.state_dict(), self.IMAGE_SEGMENTATION_CKPT)
        print('model weights saved successfully')

    def forward(self, x):
        """Predict frames from ``x``, then segment the last predicted frame."""
        x = self.frame_prediction_model(x)
        x = x[:, -1]  # keep only the final frame along dim 1
        x = self.image_segmentation_model(x)
        return x


In [10]:
# criterion = nn.CrossEntropyLoss()
# for x,y in train_loader:
#     x = x.to(device)
#     out = model(x)
#     print(out.shape, y.shape)
#     print(criterion(out, y.to(device).long()))
#     break

In [11]:

# # Instantiate frame_prediction model and segmentation_mask model
# frame_prediction_model = DLModelVideoPrediction((11, 3, 160, 240), 64, 512, groups=4)
# frame_prediction_model = nn.DataParallel(frame_prediction_model)
# frame_prediction_model = frame_prediction_model.to(device)

# best_model_path = './checkpoint_frame_prediction.pth'  # load saved model to restart from previous best model
# if os.path.isfile(best_model_path):
#     frame_prediction_model.load_state_dict(torch.load(best_model_path))

# # Instantiate frame_prediction model and segmentation_mask model
# # image_segmentation_model = FCN(49)
# image_segmentation_model = unet_model()
# image_segmentation_model = nn.DataParallel(image_segmentation_model)
# image_segmentation_model = image_segmentation_model.to(device)

# # best_model_path = 'fcn_model.pth'
# best_model_path = './image_segmentation.pth'  # load saved model to restart from previous best model
# if os.path.isfile(best_model_path):
#     image_segmentation_model.load_state_dict(torch.load(best_model_path))


In [12]:
# Hyperparameters:
num_epochs = 1
# Learning rate handed to Adam. NOTE(review): the OneCycleLR scheduler below has its own
# max_lr and presumably drives the effective rate during training - confirm `lr` still matters.
lr = 0.0001
model = combined_model(device)
# Restore sub-model weights from ./checkpoints/ when the checkpoint files exist.
model.load_weights()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr)
# Schedule length is steps_per_epoch * epochs; the training loop steps the scheduler once per batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, steps_per_epoch=len(train_loader),
                                                epochs=num_epochs)



In [13]:
# Training Loop

# FLOW:
# get 11 frames of video from dataloader (optional: Data Augmentation)
# pass it through model to get prediction for 22nd frame
# pass prediction through segmentation model

In [16]:
# Average training loss of every epoch.
train_losses = []
# One tensor per evaluated epoch: the outputs of all evaluation batches concatenated
# along the batch dimension, so preds_per_epoch[-1] covers the whole evaluation set.
preds_per_epoch = []
for epoch in range(num_epochs):
    train_loss = []
    model.train()
    train_pbar = tqdm(train_loader)

    for batch_x, batch_y in train_pbar:
        optimizer.zero_grad()
        # CrossEntropyLoss needs integer class indices as targets, hence .long().
        batch_x, batch_y = batch_x.to(device), batch_y.to(device).long()
        pred_y = model(batch_x)
        loss = criterion(pred_y, batch_y)
        train_loss.append(loss.item())
        train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))
        loss.backward()
        optimizer.step()
        # The scheduler was sized as len(train_loader) * num_epochs steps: one step per batch.
        scheduler.step()

    train_loss = np.average(train_loss)
    print(f"Average train loss {train_loss}")
    train_losses.append(train_loss)
    model.save_weights()

    # Leave the model in eval mode after every epoch; the next epoch switches back to train.
    model.eval()

    # Evaluate on every second epoch and always on the final one, so the last
    # entry of preds_per_epoch comes from the fully trained model.
    if epoch % 2 == 0 or epoch == num_epochs - 1:
        epoch_preds = []
        val_pbar = tqdm(val_loader)
        with torch.no_grad():
            for batch_x in val_pbar:
                batch_x = batch_x.to(device)
                # Move outputs to CPU so they do not pile up in accelerator memory.
                epoch_preds.append(model(batch_x).float().cpu())
        if epoch_preds:
            preds_per_epoch.append(torch.cat(epoch_preds, dim=0))


train loss: 3.6833: 100%|█████████████████████████| 2/2 [00:19<00:00,  9.73s/it]


Average train loss 3.7138378620147705
model weights saved successfully


100%|█████████████████████████████████████████████| 1/1 [00:02<00:00,  2.54s/it]


In [19]:
# Save the last entry appended to preds_per_epoch as the submission file.
# Raises IndexError if no evaluation output was collected.
# NOTE(review): the recorded output shows shape (1, 49, 160, 240), presumably per-class
# scores rather than a label mask - confirm the expected submission format.
latest_predictions = preds_per_epoch[-1]
torch.save(latest_predictions, 'The_Big_Epochalypse_submission.pt')

torch.Size([1, 49, 160, 240])