In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from os.path import expanduser, join, exists
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from data.oflow_fp_dl import OflowDlDataset

import sys
sys.path.insert(0, os.path.expanduser('~/Code/cutils'))  
from path_util import mkdir_p, save_pkl

from datetime import datetime
import pytz


# In[3]:


# Hardware visible to this kernel.
num_gpus = torch.cuda.device_count()
cuda = torch.cuda.is_available()

# Dataset root and the files/directories resolved beneath it.
root_path = expanduser('~/Data/tremor_hand_of/fp_uv_segments/')
data_info_path = join(root_path, 'freq_label_dic.json')
pkl_root = join(root_path, 'segment_16frame')

# Participant ids held out for validation / testing.
val_pids = ['P024', 'P025', 'P026', 'P029']
test_pids = ['P005']

# Training schedule: first iteration index and the patience used by the
# training loop (also the warm-up before any checkpoint is written).
start_it = 0
it_tolerance = 200

# Split tags embedded in every file name below.
# NOTE: test ids are sorted before joining, validation ids are joined as listed.
val_str = '_'.join(val_pids)
test_str = '_'.join(sorted(test_pids))
_split_ids = (test_str, val_str)

# Pickled sample lists for each split.
train_fpath = join(root_path, 'dl_amp_5_pos_1-29pts-Test_{}-Val_{}_train.pkl'.format(*_split_ids))
val_fpath = join(root_path, 'dl_amp_5_pos_1-29pts-Test_{}-Val_{}_val.pkl'.format(*_split_ids))
test_fpath = join(root_path, 'dl_amp_5_pos_1-29pts-Test_{}-Val_{}_test.pkl'.format(*_split_ids))

# Checkpoint directory for this experiment.
model_name = 'fp_z10_classifyOnly_29pts_3boosts_Test_{}_Val_{}'.format(*_split_ids)
model_path = expanduser('~/myData/aae/model_{}'.format(model_name))
mkdir_p(model_path)

# Template of the checkpoint to resume from; the remaining '{}' placeholder is
# filled with the network tag (e.g. 'Q') when the model is loaded.
_cont_prefix = "{}_{}_".format(model_name, start_it)
cont_model_fpath = join(model_path, _cont_prefix + "{}_latest.pt")

# Pretrained weights: classifier initialisation and the frozen generator.
init_model_fpath = expanduser('~/Data/tremor_hand_pretrained_models/init_models/init_z10_cls_9freq_200_Q_latest.pt')
gen_model_fpath = expanduser('~/Data/tremor_hand_pretrained_models/generator/gen_fp_z10_excl_alltest_1700_P_latest.pt')

# Directory receiving figures, loss pickles and accuracy reports.
out_dir = expanduser('~/myData/aae/out_aae_{}'.format(model_name))
mkdir_p(out_dir)


# In[4]:


# Name fragments of recordings/segments that are passed to the training
# datasets as `discard_list`.
discard_list = ['freq_010-P006-vid_001',
                'P009',
                'freq_004-P011-vid_001',
                'freq_004-P013',

                'freq_010-P008-vid_002',
                # NOTE(review): this entry used to be listed twice; if the second
                # occurrence was meant to be a different video, add it here.
                'freq_004-P015-vid_002',
                'freq_010-P015-vid_001',
                'freq_010-P016-vid_001',
                'freq_010-P017-vid_001',
                'freq_010-P020-vid_001',
                'freq_010-P020-vid_002',
                'freq_010-P022-vid_001',
                'freq_010-P022-vid_002',
                'freq_010-P023-vid_001',

                'uvseg_amp_5-pos_1-freq_004-P016-vid_002', #out of FOV
                'uvseg_amp_5-pos_1-freq_010-P016-vid_001_frame_002161-002225',
                'uvseg_amp_5-pos_1-freq_010-P016-vid_001_frame_002177-002241',
                'uvseg_amp_5-pos_1-freq_010-P016-vid_001_frame_002193-002257',
                'uvseg_amp_5-pos_1-freq_010-P016-vid_001_frame_002209-002273',
                'uvseg_amp_5-pos_1-freq_010-P016-vid_001_frame_002225-002289',
                'uvseg_amp_5-pos_1-freq_010-P016-vid_001_frame_002241-002305',
                'uvseg_amp_5-pos_1-freq_010-P016-vid_001_frame_002257-002321',
                'uvseg_amp_5-pos_1-freq_010-P016-vid_001_frame_002273-002337',
                'uvseg_amp_5-pos_1-freq_010-P029-vid_001', # too dark to see motion
                'uvseg_amp_5-pos_1-freq_010-P029-vid_002'
               ]

nz = 10 # size of latent vector --> action type. TODO: 2d enough?
slr_fac = 0.03 # scale applied to `lr` for the supervised (classifier) optimiser

# Derive the flag from the runtime instead of hard-coding True, so that it
# cannot disagree with the device the tensors actually live on.
cuda = torch.cuda.is_available()
lr = 1e-4
batch_size = 64 #8*num_gpus # number of segments in each batch
in_size = 64

fps = 30
ns = 2 # time series length (in seconds)
nfrm = ns * fps + 4 # time series length in frames; the +4 pads 60 up to 64
nc = 2 # number of channels (each frm in series has theta & rad)
ngf = 64 # decoder (generator) filter factor
ndf = 64 # encoder filter factor
h_dim = 128 # discriminator hidden size
lam = 1 # regularization coefficient
n_labels = 3 # total number of freq labels


# In[5]:


class Decoder3D_yz(nn.Module):
    """3-D transposed-convolution generator conditioned on a label and a latent code.

    The one-hot label ``y`` and the latent vector ``z`` are projected by two
    separate transposed convolutions, summed, and then upsampled by four
    stride-2 stages. Layer widths come from the module-level ``n_labels``,
    ``nz``, ``ngf`` and ``nc`` settings.
    """

    def __init__(self):
        super(Decoder3D_yz, self).__init__()

        # Separate projections of y and z; a 1x1x1 input becomes (ngf*8) x 4 x 4 x 4.
        self.convt_y = nn.ConvTranspose3d(n_labels, ngf*8, kernel_size=4, stride=1, padding=0, bias=False)
        self.convt_z = nn.ConvTranspose3d(nz, ngf*8, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm3d(ngf * 8)
        # Each stage below (kernel 4, stride 2, padding 1) doubles every extent.
        # -> (ngf*4) x 8 x 8 x 8
        self.convt2 = nn.ConvTranspose3d(ngf*8, ngf*4, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(ngf * 4)
        # -> (ngf*2) x 16 x 16 x 16
        self.convt3 = nn.ConvTranspose3d(ngf * 4, ngf*2, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(ngf * 2)
        # -> (ngf) x 32 x 32 x 32
        self.convt4 = nn.ConvTranspose3d(ngf*2,  ngf, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm3d(ngf)
        # -> (nc) x 64 x 64 x 64
        self.convt5 = nn.ConvTranspose3d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False)

    def forward(self, y, z):
        """Generate a batch of segments.

        Args:
            y: label tensor with ``n_labels`` channels, 5-D (batch, channel, D, H, W).
            z: latent tensor with ``nz`` channels and the same trailing extents as ``y``.

        Returns:
            Tensor with ``nc`` channels whose values lie in [0, 1] (sigmoid output).
        """
        x = self.convt_y(y) + self.convt_z(z)
        x = F.relu(self.bn1(x), True)
        x = F.relu(self.bn2(self.convt2(x)), True)
        x = F.relu(self.bn3(self.convt3(x)), True)
        x = F.relu(self.bn4(self.convt4(x)), True)
        # torch.sigmoid replaces the deprecated F.sigmoid; the result is identical.
        return torch.sigmoid(self.convt5(x))


# In[6]:


class Encoder3D_yz(nn.Module):
    """3-D convolutional classifier producing per-label probabilities.

    Four stride-2 ``Conv3d`` stages are followed by a 4x4x4 convolution with
    ``n_labels`` output channels and a softmax over the channel axis. Only the
    label branch (``y``) is implemented; no latent code is produced.
    Layer widths come from the module-level ``nc``, ``ndf`` and ``n_labels``.
    """

    def __init__(self):
        super(Encoder3D_yz, self).__init__()
        # input is (nc=2 [u,v]) x nfrm x 64 x 64
        # Each stage below (kernel 4, stride 2, padding 1) halves every extent;
        # the sizes in the comments assume a 64 x 64 x 64 input.
        self.conv1 = nn.Conv3d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False)
        # -> (ndf) x 32 x 32 x 32
        self.conv2 = nn.Conv3d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(ndf * 2)
        # -> (ndf*2) x 16 x 16 x 16
        self.conv3 = nn.Conv3d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(ndf * 4)
        # -> (ndf*4) x 8 x 8 x 8
        self.conv4 = nn.Conv3d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm3d(ndf * 8)
        # -> (ndf*8) x 4 x 4 x 4, collapsed to (n_labels) x 1 x 1 x 1 by conv_y
        self.conv_y = nn.Conv3d(ndf * 8, n_labels, kernel_size=4, stride=1, padding=0, bias=False)
        # Explicit channel axis: this is the axis the implicit-dim softmax picked
        # for 5-D input, so results are unchanged and the deprecation warning goes.
        self.softmax_y = nn.Softmax(dim=1)

    def forward(self, x, print_size = False):
        """Classify a batch of segments.

        Args:
            x: 5-D input tensor with ``nc`` channels.
            print_size: when True, print the input and output sizes.

        Returns:
            Tensor with ``n_labels`` channels holding softmax probabilities
            (they sum to 1 along dim 1).
        """
        if print_size:
            print('\t\tIn Model: input size {}'.format(x.size()))
        x = F.leaky_relu(self.conv1(x), 0.2, inplace=True)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, inplace=True)
        y = self.softmax_y(self.conv_y(x))
        if print_size:
            print('\t\tIn Model: output y size {}'.format(y.size()))
        return y


# In[7]:


def plot_and_save_segment_images(samples, out_dir, fig_title, convert=False):
    """Plot the first frames of one segment and save the figure as a PNG.

    Args:
        samples: array indexed as ``samples[channel, frame, row, col]``; channel 0
            is drawn in the green plane and channel 1 in the blue plane.
        out_dir: directory receiving ``<fig_title>.png`` (created when missing).
        fig_title: title placed on every panel and used as the file stem.
        convert: unused; kept so existing callers keep working.

    At most ``ncols`` (8) frames are drawn. Segments with fewer frames are
    plotted in full instead of raising an IndexError.
    """
    ncols = 8
    n_frames = samples.shape[1]
    # The grid is still sized for every frame of the segment, as before.
    nrows = int(np.ceil(n_frames/8.))
    sub_plot_sz = 3.5
    fig = plt.figure(figsize=(sub_plot_sz*ncols, sub_plot_sz*nrows))
    gs = gridspec.GridSpec(nrows, ncols)
    gs.update(wspace=0.05, hspace=0.05)

    for i in range(min(ncols, n_frames)):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.set_title(fig_title)

        # Compose an RGB image: red stays zero, the two channels fill G and B.
        img = np.zeros((samples.shape[2], samples.shape[3], 3), dtype=np.float32)
        img[:,:,1] = samples[0, i, :, :]
        img[:,:,2] = samples[1, i, :, :]

        plt.imshow(img)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    plt.savefig('{}/{}.png'.format(out_dir, fig_title), bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# In[8]:


def plot_and_save_batch_images(samples, out_dir, fig_title, y = None, convert=False):
    """Plot and save the first segment of a batch.

    Args:
        samples: batch indexed as ``samples[segment, channel, frame, row, col]``.
        out_dir: output directory forwarded to ``plot_and_save_segment_images``.
        fig_title: base title / file stem.
        y: optional per-segment labels; when given, ``_y:<label>`` is appended
            to the title of the plotted segment.
        convert: when True ``samples`` is a torch tensor and is converted to a
            NumPy array first.
    """
    if convert:
        # detach().cpu() works on any device and with or without autograd
        # history, so no global CUDA flag is needed.
        samples = samples.detach().cpu().numpy()

    # Only the first segment is plotted; an empty batch plots nothing.
    n_plotted = min(1, len(samples))
    for seg_idx in range(n_plotted):
        # Build the title per segment instead of appending to `fig_title`, so
        # labels would not pile up if more than one segment were plotted.
        seg_title = fig_title
        if y is not None:
            seg_title = '{}_y:{}'.format(fig_title, y[seg_idx])
        plot_and_save_segment_images(samples[seg_idx], out_dir, seg_title)


# In[9]:


def plot_losses_and_save_fig(S_losses, train_accuracies, val_accuracies, test_accuracies, out_dir, fig_title_prefix):
    """Persist the training curves and plot them as a four-panel figure.

    Writes ``<out_dir>/<fig_title_prefix>_loss_data.pkl`` with the raw series
    and ``<out_dir>/<fig_title_prefix>_loss.png`` with one panel each for the
    supervised loss and the train / validation / test accuracies.

    Args:
        S_losses: supervised loss values, one per recorded iteration.
        train_accuracies: training accuracies.
        val_accuracies: validation accuracies.
        test_accuracies: test accuracies.
        out_dir: destination directory (must already exist).
        fig_title_prefix: prefix for the panel titles and both file names.
    """
    save_pkl({'test_acc': test_accuracies,
              'val_acc': val_accuracies,
              'train_acc': train_accuracies,
              'S': S_losses
             },
             '{}/{}_loss_data.pkl'.format(out_dir, fig_title_prefix))

    # (title template, series) for each panel, top to bottom.
    panels = [
        ('{}_Supervise_loss', S_losses),
        ('{}_Train_Accuracy', train_accuracies),
        ('{}_Val_Accuracy', val_accuracies),
        ('{}_Test_Accuracy', test_accuracies),
    ]

    fig = plt.figure(figsize=(12, 36))
    for row, (title_fmt, series) in enumerate(panels, 1):
        ax = plt.subplot(len(panels), 1, row)
        ax.set_title(title_fmt.format(fig_title_prefix))
        plt.plot(series)

    plt.savefig('{}/{}_loss.png'.format(out_dir, fig_title_prefix), bbox_inches='tight')
    plt.show()
    # This function is called repeatedly during training; close the figure so
    # open figures (and their memory) do not accumulate.
    plt.close(fig)


# In[10]:


def calculate_accuracy_batches(data_loader, Q):
    """Return the classification accuracy of ``Q`` over every batch of a loader.

    The model is switched to evaluation mode and run without autograd, so that
    measuring accuracy neither updates BatchNorm running statistics with
    validation/test data nor builds graphs. Its previous train/eval mode is
    restored afterwards.

    Args:
        data_loader: iterable yielding 5-tuples whose first two items are the
            input batch and the integer class labels.
        Q: ``nn.Module`` mapping an input batch to per-class scores.

    Returns:
        Fraction of correctly classified samples as a float, or 0.0 when the
        loader yields no samples.
    """
    tp = 0
    tot = 0
    was_training = Q.training
    Q.eval()
    try:
        with torch.no_grad():
            for data, batch_Y, _, _, _ in data_loader:
                curr_batch_size = len(batch_Y)
                scores = Q(data.to(device))
                # Flatten to (batch, n_classes) and take the arg-max class.
                y_ = np.argmax(scores.cpu().view(curr_batch_size, -1).numpy(), axis=1)
                labels = batch_Y.cpu().numpy()
                tp += np.sum(y_ == labels)
                tot += len(labels)
    finally:
        Q.train(was_training)

    if tot == 0:
        return 0.0
    return float(tp)/tot

def calculate_accuracy(train_loader, val_loader, test_loader, Q):
    """Evaluate ``Q`` on the three splits.

    Returns:
        Tuple ``(train_acc, val_acc, test_acc)``, computed in that order with
        ``calculate_accuracy_batches``.
    """
    splits = (train_loader, val_loader, test_loader)
    return tuple(calculate_accuracy_batches(loader, Q) for loader in splits)


# In[11]:


# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Use {} GPUs'.format(num_gpus))


# In[12]:


import collections


def _load_and_parallelize(model, state_dict):
    """Load ``state_dict`` into ``model`` and return it wrapped in ``nn.DataParallel``.

    Checkpoints saved from a ``DataParallel`` model prefix their keys with
    ``module.``; those are loaded after wrapping, plain checkpoints before.
    """
    keys = list(state_dict.keys())  # list(): dict views are not indexable on Python 3
    if keys and 'module.' in keys[0]: #trained using DataParallel
        print('loading parallel model ...')
        model = nn.DataParallel(model)
        model.load_state_dict(state_dict)
    else:
        model.load_state_dict(state_dict)
        model = nn.DataParallel(model)
    return model


# Q is the classifier being trained; trained_P is the pretrained generator
# that only supplies synthetic samples.
Q = Encoder3D_yz()
trained_P = Decoder3D_yz()

# The generator checkpoint is loaded into the wrapped model, i.e. it is
# expected to carry 'module.'-prefixed keys.
trained_P = nn.DataParallel(trained_P)
if gen_model_fpath is not None and exists(gen_model_fpath):
    print('loading pretrained generator ... {}'.format(gen_model_fpath))
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    trained_P.load_state_dict(torch.load(gen_model_fpath, map_location=device))
else:
    print('WARNING: no pretrained generator loaded')

# Classifier weights: prefer a checkpoint of this run, then the initial model.
cont_model_qpath = cont_model_fpath.format('Q')
if exists(cont_model_qpath):
    print('loading continue model... {}'.format(cont_model_qpath))
    loaded_model = torch.load(cont_model_qpath, map_location=device)
    print('loaded_model keys = {}'.format(loaded_model.keys()))
    Q = _load_and_parallelize(Q, loaded_model)
elif init_model_fpath is not None and exists(init_model_fpath):
    print('loading initial model... {}'.format(init_model_fpath))
    loaded_model = torch.load(init_model_fpath, map_location=device)
    Q = _load_and_parallelize(Q, loaded_model)
else:
    print('WARNING: no initial model loaded')
    # Wrap here too, so Q is a DataParallel model on every path.
    Q = nn.DataParallel(Q)

Q = Q.to(device)
trained_P = trained_P.to(device)

def reset_grad():
    """Clear the accumulated gradients of the classifier ``Q``."""
    Q.zero_grad()

# Only Q is optimised; its learning rate is the base rate scaled by slr_fac.
Qsup_solver = optim.Adam(Q.parameters(), lr=slr_fac*lr)

In [None]:
# In[13]:


def tensor_to_value(t_var):
    """Return the contents of a tensor as a NumPy array.

    Works for tensors on any device and with or without autograd history, so
    it does not depend on the global ``cuda`` flag.

    Args:
        t_var: a ``torch.Tensor``.

    Returns:
        ``numpy.ndarray`` with the same shape and values as ``t_var``.
    """
    return t_var.detach().cpu().numpy()


# In[14]:


# Per-iteration history filled by the training loop.
S_losses, train_accuracies, val_accuracies, test_accuracies = [], [], [], []

# Timestamps in the training log are reported in US/Eastern time.
est = pytz.timezone('US/Eastern')
print('Time Zone: {}'.format(est))
est_time = datetime.now(est)
print('default format: {}'.format(est_time))


def _make_loader(dl_path, shuffle, **dataset_kwargs):
    """Build a DataLoader over an OflowDlDataset read from ``dl_path``.

    ``dataset_kwargs`` are forwarded to the dataset unchanged; batch size and
    worker count come from the notebook-level settings.
    """
    dataset = OflowDlDataset(dl_path=dl_path,
                             data_info_path=data_info_path,
                             pkl_root=pkl_root,
                             **dataset_kwargs)
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      num_workers=num_gpus,
                      shuffle=shuffle)


# Training split with no_aug_ratio = 1.0.
train_loader = _make_loader(train_fpath, True,
                            discard_list=discard_list,
                            no_aug_ratio=1.0,
                            seg_random_seed=123)

# Same training split with no_aug_ratio = 0.0.
train_loader_aug = _make_loader(train_fpath, True,
                                discard_list=discard_list,
                                no_aug_ratio=0.0,
                                seg_random_seed=123)

# Held-out splits: dataset defaults, fixed order.
val_loader = _make_loader(val_fpath, False)
test_loader = _make_loader(test_fpath, False)


# In[15]:


def train_classify_network(X, y_onehot, convert = True):
    """Run one supervised optimisation step of the classifier ``Q``.

    Used with real, augmented and generated batches alike.

    Args:
        X: input batch fed to ``Q``.
        y_onehot: one-hot targets, one row per sample (assumed to be a float
            tensor); moved to ``device`` here.
        convert: when True ``X`` is moved to ``device`` first. Pass False when
            it already lives there (e.g. generator output).

    Returns:
        The batch loss as a detached scalar tensor.
    """
    if convert:
        X = X.to(device)

    y_onehot = y_onehot.to(device)

    curr_batch_size = X.size(0)

    # Q ends in a softmax, so its output is already a probability per label.
    # Score it with plain binary cross-entropy: the *_with_logits variant would
    # push these probabilities through an extra sigmoid and squash the signal.
    y_prob = Q(X).view(curr_batch_size, -1)
    S_loss = F.binary_cross_entropy(input=y_prob, target=y_onehot)
    S_loss.backward()
    Qsup_solver.step()
    reset_grad()

    # Detach so callers that keep the value do not keep the graph alive.
    return S_loss.detach()


# In[ ]:


num_of_batches = len(train_loader)
print('batch len = {}'.format(num_of_batches))

max_it = 100000      # hard upper bound on the number of iterations
max_val_acc = 0.0    # best validation accuracy seen so far
it_sum = 0           # iterations since the validation accuracy last improved
stop_training = False

# Each iteration trains Q in up to two phases:
#   1. a boost phase (skipped on the first iteration): augmented real data on
#      odd iterations, generator samples on even iterations;
#   2. one pass over the un-augmented training data, followed by evaluation,
#      checkpointing and the early-stopping check.
# NOTE(review): checkpoints are written with the `_Q_best.pt` suffix, whereas
# `cont_model_fpath` looks for `_latest.pt` - confirm how runs are resumed.
for it in range(start_it, max_it):
    # --- boost phase, odd iterations: augmented real data --------------------
    if it > start_it and it % 2 == 1: #use real image first
        for batch_idx, (data, batch_Y, y_onehot, _, augs) in enumerate(train_loader_aug):
            S_loss = train_classify_network(data, y_onehot, convert = True)

            if it % 100 == 1:
                plot_and_save_batch_images(data, out_dir=out_dir, fig_title='{0}_sup_aug_y'.format(it), y=batch_Y)

    # --- boost phase, even iterations: samples from the pretrained generator --
    if it > start_it and it % 2 == 0: #use real image first
        for gen_i in range(num_of_batches-1):
            curr_batch_size = batch_size

            z_gen = torch.randn(curr_batch_size, nz) #gaussian distribution

            # randomly generate y: 0.3 freq000, 0.3 freq004, 0.4 freq010 since fewer freq010 real data
            y_gen = np.random.randint(low = 0, high = 10, size = curr_batch_size)
            y_gen = [0 if e <= 2 else 1 if e<=5 else 2 for e in list(y_gen)]
            y_onehot_gen = np.zeros(shape = (curr_batch_size, n_labels), dtype=np.float32)
            y_onehot_gen[np.array(range(curr_batch_size)), np.array(y_gen)] = 1.
            y_onehot_gen = torch.from_numpy(y_onehot_gen)

            # Append three singleton axes so y and z are 5-D (batch, channel, 1, 1, 1).
            z_gen = z_gen.unsqueeze(2).unsqueeze(3).unsqueeze(4).to(device)
            y_in = y_onehot_gen.unsqueeze(2).unsqueeze(3).unsqueeze(4).to(device)

            # The generator is never optimised here, so skip building its graph.
            # NOTE(review): trained_P stays in train mode (BatchNorm uses batch
            # statistics) - confirm whether eval() is intended for sampling.
            with torch.no_grad():
                X_gen = trained_P(y_in, z_gen)

            # supervised training using generated images
            S_loss = train_classify_network(X_gen, y_onehot_gen, convert = False)

            if it == 2 or it % 100 == 0:
                plot_and_save_batch_images(X_gen, out_dir=out_dir, fig_title='{0}_sup_syn_y'.format(it), y=y_gen, convert=True)

    # --- main phase: un-augmented real data -----------------------------------
    for batch_idx, (data, batch_Y, y_onehot, _, _) in enumerate(train_loader):
        if it == start_it and batch_idx == 0:
            print('\tOutside Model, input size {}'.format(data.size()))

        # supervised classification phase
        S_loss = train_classify_network(data, y_onehot, convert = True)

        # Everything below runs once per iteration, after the last batch.
        if batch_idx != num_of_batches-1:
            continue

        S_losses.append(tensor_to_value(S_loss))
        train_acc, val_acc, test_acc = calculate_accuracy(train_loader, val_loader, test_loader, Q)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        test_accuracies.append(test_acc)

        est_time = datetime.now(est)
        if val_acc > max_val_acc:
            max_val_acc = val_acc
            it_sum = 0

            # save model
            if it > it_tolerance: #not save too many models at the beginning
                torch.save(Q.state_dict(), join(model_path, "{}_{}_Q_best.pt".format(model_name, it)))
                with open(join(out_dir, 'aae_iter_{}_acc.txt'.format(it)), 'w') as f:
                    f.write('train_acc: {}, val_acc: {}, test_acc: {}\n'.format(train_acc, val_acc, test_acc))
        else:
            if it_sum > it_tolerance:
                # Patience exhausted. A bare `break` only leaves this batch
                # loop, so raise a flag that also ends the iteration loop.
                stop_training = True
                break
            it_sum += 1

        # .item() reads the scalar loss; indexing a 0-dim tensor with [0] fails
        # on current PyTorch.
        print('Iter-{} Batch-{}: {}, S_loss: {:.3}; train_acc: {:.3}; val_acc: {:.3}; test_acc: {:.3}, it_sum={}'
          .format(it, batch_idx, est_time, S_loss.item(), train_acc, val_acc, test_acc, it_sum))

        if it > start_it and it % 100 == 0: #at the end of each iteration
            #plot loss curve
            plot_losses_and_save_fig(S_losses, train_accuracies, val_accuracies, test_accuracies, out_dir=out_dir,
                                     fig_title_prefix='Plot_Iter_{}'.format(it)
                                    )

            plot_and_save_batch_images(data, out_dir=out_dir, fig_title='Img{0}_uv_y'.format(it), y=batch_Y)

    if stop_training:
        print('Early stopping at Iter-{}: val_acc has not improved for {} iterations (best {:.3})'
              .format(it, it_sum, max_val_acc))
        break