In [1]:
from __future__ import print_function, division
import sys
sys.path.append("../")

from dsbtorch import PreEmbeddedDataset

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import pathlib
import pickle

import tqdm

from CNN_RNN import *
from torch.utils.tensorboard import SummaryWriter

In [2]:
plt.ion()   # interactive mode

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

data_dir = pathlib.Path("/home/ubuntu/data/encoded_dataset_511/")

dataset_names = ['train', 'dev'] #, 'test']
datasets = {x: PreEmbeddedDataset(data_dir / x) for x in dataset_names}

In [3]:
def collate_fn(batches):
    return list(batch[0] for batch in batches), list(batch[1] for batch in batches)

batch_size = 64
dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=batch_size, num_workers=6, pin_memory=True, collate_fn=collate_fn) for x in dataset_names}
dataset_sizes = {x: len(datasets[x]) for x in dataset_names}

In [13]:
def compute_IOU_from_indices(pred_starts, pred_ends, label_starts, label_ends):
    pred_starts, pred_ends, label_starts, label_ends = tuple(x.cpu().numpy() for x in (pred_starts, pred_ends, label_starts, label_ends))
    
    pred_ends += 1
    label_ends += 1
    intersection = np.maximum(0.0, np.minimum(pred_ends, label_ends) - np.maximum(pred_starts, label_starts))
    union = (pred_ends - pred_starts) + (label_ends - label_starts) - intersection
    return np.divide(intersection, union, out=np.zeros_like(intersection), where=union!=0)

In [6]:
def get_start_and_end_labels(labels):
    diffs = labels[..., 1:] - labels[..., :-1]

    # 1's for frames where the start of a sponsored segment occurs
    start_labels = torch.cat((torch.unsqueeze(labels[..., 0], dim=-1), diffs), dim=-1) == 1
    start_labels = start_labels.long()

    # 1's for frames where the end of a sponsored segment occurs
    end_labels = torch.cat((diffs, -torch.unsqueeze(labels[..., -1], dim=-1)), dim=-1) == -1
    end_labels = end_labels.long()                

    return start_labels.to(device), end_labels.to(device)

In [7]:
def train_model(rnn_decoder, criterion, optimizer, scheduler, output_path, num_epochs=25):
    writer = SummaryWriter()

    since = time.time()

    best_decoder_wts = copy.deepcopy(rnn_decoder.state_dict())
    best_iou = 0.0

    for epoch in range(num_epochs):
        print('\n\nEpoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        epoch_loss = {}
        epoch_iou = {}

        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']:
            if phase == 'train':
                rnn_decoder.train()
                print('Training for one epoch.')
                print('-' * 8)
            else:
                rnn_decoder.eval()
                print('Evaluating model.')
                print('-' * 8)

            running_loss = 0.0
            total_iou = 0

            i = 0
            
            for cnn_outputs, labels in tqdm.tqdm(dataloaders[phase]):
                cnn_outputs = nn.utils.rnn.pack_sequence(cnn_outputs, enforce_sorted=False).to(device)
                
                # Get the RNN labels
                padded_labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
                padded_start_labels, padded_end_labels = get_start_and_end_labels(padded_labels)
                
                start_idxs = torch.argmax(padded_start_labels, dim=-1)
                end_idxs = torch.argmax(padded_end_labels, dim=-1)
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    start_probs_ps, end_probs_ps = rnn_decoder(cnn_outputs)

                    start_probs_all, start_lengths = torch.nn.utils.rnn.pad_packed_sequence(start_probs_ps, batch_first=True)
                    end_probs_all, end_lengths = torch.nn.utils.rnn.pad_packed_sequence(end_probs_ps, batch_first=True)

                    start_probs_all = torch.squeeze(start_probs_all, dim=-1)
                    end_probs_all = torch.squeeze(end_probs_all, dim=-1)
                    
                    assert torch.all(start_lengths == end_lengths)
                    
                    loss = criterion(start_probs_all, start_idxs) + criterion(end_probs_all, end_idxs)
                                       
                    start_preds_idx = torch.argmax(start_probs_all, dim=-1)
                    end_preds_idx = torch.argmax(end_probs_all, dim=-1)
                    batch_iou = np.sum(compute_IOU_from_indices(start_preds_idx, end_preds_idx, start_idxs, end_idxs))

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    # statistics
                    batch_loss = loss.item()
                    running_loss += batch_loss
                    total_iou += batch_iou

                if phase == 'train':
                    num_batches = len(dataloaders[phase])
                    writer.add_scalar("Batch Loss/" + phase, batch_loss, epoch * num_batches + i)
                    writer.add_scalar("Batch IOU/" + phase, batch_iou / len(labels), epoch * num_batches + i)
            
                i += 1

            if phase == 'train':
                scheduler.step()

            epoch_loss[phase] = running_loss / dataset_sizes[phase]
            epoch_iou[phase] = total_iou / dataset_sizes[phase]
            
            writer.add_scalar("Loss/" + phase, epoch_loss[phase], epoch)
            writer.add_scalar("IOU/" + phase, epoch_iou[phase], epoch)

            print('{} Loss: {:.4f} IOU: {:.4f}'.format(
                phase, epoch_loss[phase], epoch_iou[phase]))

            torch.save(rnn_decoder.state_dict(), output_path + str(epoch))
            
            # deep copy the model
            if phase == 'dev':
                best_iou = epoch_iou[phase]
                best_decoder_wts = copy.deepcopy(rnn_decoder.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val F score: {:4f}'.format(best_iou))

    # Save and load best model weights
    rnn_decoder.load_state_dict(best_decoder_wts)
    torch.save(rnn_decoder.state_dict(), output_path)
    return rnn_decoder

In [8]:
decoder = nn.Sequential(
    Embedder(2048),
    DecoderRNN(sigmoid=False)
)

In [9]:
decoder = decoder.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer = optim.Adam(decoder.parameters(), lr=0.001)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.01)

In [10]:
# decoder.load_state_dict(torch.load("/home/ubuntu/data/DeepSponsorBlock/results/rnn.weights.decoder"))

In [11]:
model = train_model(decoder, criterion, optimizer, exp_lr_scheduler, '../results/rnn_full_ce.weights', num_epochs=21)

  0%|          | 0/235 [00:00<?, ?it/s]



Epoch 0/20
----------
Training for one epoch.
--------



  0%|          | 0/16 [00:00<?, ?it/s][A

train Loss: 0.0011 IOU: 0.0000
Evaluating model.
--------




  0%|          | 0/235 [00:00<?, ?it/s][A[A

dev Loss: 0.0194 IOU: 0.0037


Epoch 1/20
----------
Training for one epoch.
--------





  0%|          | 0/16 [00:00<?, ?it/s][A[A[A

train Loss: 0.0011 IOU: 0.0005
Evaluating model.
--------


KeyboardInterrupt: 