In [1]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torch.nn.utils import clip_grad_norm_
from time import time
from datetime import datetime
from model import MyModel
from data import MyDataset, Collates
import argparse
import torch.nn.functional as F
from torch_future import OneCycleLR, Nadam, SGDW
from torch import nn, optim
from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import confusion_matrix
from data import emotions
from collections import Counter

In [2]:
import importlib
import data
importlib.reload(data)
from data import MyDataset, Collates

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
collates = Collates()

train_path = '/usr/cs/public/mohd/data/train'
train_txt = '/usr/cs/public/mohd/train_data_labeled.txt'
train_dataset = MyDataset(train_path, train_txt, size=500)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
        collate_fn = collates.collate_fn, num_workers=0)

val_path = '/usr/cs/public/mohd/data/val'
val_txt = '/usr/cs/public/mohd/val_data_labeled.txt'
val_dataset = MyDataset(val_path, val_txt, mode='val', size=-1)
val_dataset.update_data()
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True,
        collate_fn = collates.val_collate_fn, num_workers=0)

100%|██████████| 331/331 [01:56<00:00,  2.83it/s]


In [None]:
for i,item in enumerate(train_dataset):
    n = item[2].shape[0]
    if n < 1 :
        print(n, i, train_dataset.imgs['dbpGH5iP0GE'].shape)

In [None]:
val_dataset = MyDataset(val_path, val_txt, mode='val', size=2500)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True,
        collate_fn = collates.val_collate_fn, num_workers=0)

In [None]:
import importlib
importlib.reload(torch.cuda)
torch.__version__

In [None]:
saved_train = (train_dataset.imgs, train_dataset.audios, train_dataset.texts, train_dataset.files)
saved_val = (val_dataset.imgs, val_dataset.audios, val_dataset.texts, val_dataset.files)

In [None]:
train_dataset.imgs = saved_train[0]
train_dataset.audios = saved_train[1]
train_dataset.texts = saved_train[2]
train_dataset.files = saved_train[3]
val_dataset.imgs = saved_val[0]
val_dataset.audios = saved_val[1]
val_dataset.texts = saved_val[2]
val_dataset.files = saved_val[3]
train_dataset.size = len(train_dataset.texts)
val_dataset.size = len(val_dataset.texts)

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [4]:
def combine_avg(preds):
    if preds and len(preds) > 0:
        return np.expand_dims(np.mean(np.concatenate(preds),0),0)
    else:
        return []

def combine_voting(preds):
    preds = np.argmax(np.concatenate(preds), 1)
    ind = np.argmax(np.bincount(preds))
    return ind

mean = lambda l: sum(l)/len(l)

class EarlyStopping():
    def __init__(self, patience = 20, delta = 0.01):
        self.best_loss = None
        self.counter = 0
        self.patience = patience
        self.delta = delta
        
    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
        elif loss > self.best_loss + self.delta:
            self.counter += 1
        elif loss < self.best_loss:
            self.best_loss = loss
            self.counter = 0
        
        if self.counter >= self.patience:
            return True
        else:
            return False
        
def initiate_model(hparams, name, resume=None):
    import importlib
    import model
    import torch
    importlib.reload(torch)
    importlib.reload(model)
    from model import MyModel
    model = MyModel(name, hparams).to(device)

    if resume is not None:
      state = torch.load(os.path.join('/usr/cs/grad/doc/melgaar/multimodal-emo-rec/ckpt','%s_%d.pt'%resume))
    
      model.load_state_dict(state['model_state_dict'])
      model.name = state['name']
      model.iter = state['iter']
      model.stage = state['stage']
      writer = SummaryWriter(logdir='runs/' + model.name, purge_step = model.iter)
      print("Loaded model")
    else:
      writer = SummaryWriter(logdir='runs/' + name)
    
    return model, writer

def initiate_optimizers(hparams, model):
#     spk_params = list(model.classify_spk.parameters()) \
#         + list(model.classify_gen.parameters()) \
#         + list(model.classify_age.parameters()) \
#         + list(model.av_spk.parameters())
    audio_params = list(model.rnn_audio.parameters())
    img_params = list(model.rnn_img.parameters())
    text_params = list(model.rnn_text.parameters())
    av_params = list(model.av.parameters())
    at_params = list(model.at.parameters())
    vt_params = list(model.vt.parameters()) 
    int_params = list(model.avt.parameters()) + list(model.classify_integrated.parameters())

    if hparams.setting == 'aux':
        audio_params += list(model.classify_audio.parameters())
        img_params += list(model.classify_img.parameters())
        text_params += list(model.classify_text.parameters())
        av_params += list(model.classify_av.parameters())
        at_params += list(model.classify_at.parameters())
        vt_params += list(model.classify_vt.parameters())
    
    if hparams.tune_prev:
        int_params += (audio_params + img_params + text_params + av_params + at_params + vt_params)
        av_params += (audio_params + img_params)
        at_params += (audio_params + text_params)
        vt_params += (img_params + text_params)
    
    optims = {'adam':optim.Adam,
             'nadam':Nadam,
             'adamw': optim.AdamW,
             'sgd': lambda params, lr, weight_decay: optim.SGD(params,
                                                               lr,
                                                               weight_decay=weight_decay,
                                                               momentum=0.9, nesterov=True),
             'sgdw':lambda params, lr, weight_decay: SGDW(params,
                                                          lr,
                                                          weight_decay=weight_decay,
                                                          momentum=0.9, nesterov=True)}
        
    optimizers = {}
    optimizers['audio'] = optims[hparams.audio_optim](audio_params, lr = hparams.audio_lr, weight_decay = hparams.audio_reg)
    optimizers['img'] = optims[hparams.img_optim](img_params, lr = hparams.img_lr,weight_decay = hparams.img_reg)
    optimizers['text'] = optims[hparams.text_optim](text_params, lr = hparams.text_lr, weight_decay = hparams.text_reg)
    optimizers['av'] = optims[hparams.av_optim](av_params, lr = hparams.av_lr, weight_decay = hparams.av_reg)
    optimizers['at'] = optims[hparams.at_optim](at_params, lr = hparams.at_lr, weight_decay = hparams.at_reg)
    optimizers['vt'] = optims[hparams.vt_optim](vt_params, lr = hparams.vt_lr, weight_decay = hparams.vt_reg)
    optimizers['int'] = optims[hparams.int_optim](int_params, lr = hparams.int_lr, weight_decay = hparams.int_reg)
#     optimizers['spk'] = optims[hparams.spk_optim](spk_params, lr = hparams.spk_lr, weight_decay = hparams.spk_reg)
    return optimizers

def save_model(model):
    state = {'model_state_dict': model.state_dict(),
             'name': model.name,
             'iter': model.iter,
             'stage': model.stage
            }
    torch.save(state, os.path.join('/usr/cs/grad/doc/melgaar/multimodal-emo-rec/ckpt', "%s_%d.pt"%(model.name,
                                                                                                  model.iter)))

In [5]:
def acc_fn(pred, label):
    pred[pred > 0.5] = 1
    pred[pred <= 0.5] = 0
    pred = pred.detach().cpu()
    label = label.detach().cpu()
    acc = (pred == label).numpy().mean()
    
    f1s, waccs = [], []
    for emo in range(label.shape[1]):
        sublabels = label[:,emo]
        subpreds = pred[:,emo]
        f1 = f1_score(sublabels, subpreds, average='binary')
        counts = Counter(sublabels.tolist())
        n, p = counts[0], counts[1]
        tn, fp, fn, tp = confusion_matrix(sublabels, subpreds).ravel()
        wacc = (tp * n/p + tn) / (2*n)
        f1s.append(f1)
        waccs.append(wacc)
    

    return acc, mean(f1s), mean(waccs)


def train_step(model, optimizers, batch, writer):
    bs = batch[0].shape[0]

    if model.hparams.masking:
        #text
        mask = torch.ones((bs,1,1), device=batch[0].device)
        text_mask = F.normalize(F.dropout(mask, model.hparams.mask_rate))
        batch[1] *= text_mask

        #audio
        mask = torch.ones((bs,1,1), device=batch[0].device)
        audio_mask = F.normalize(F.dropout(mask, model.hparams.mask_rate))
        batch[0] *= audio_mask

        #imgs
        mask = torch.ones((bs,1,1), device=batch[0].device)
        imgs_mask = F.normalize(F.dropout(mask, model.hparams.mask_rate))
        batch[3] *= imgs_mask

    results = model(batch[:6])
    y_int = batch[6]

#     if masking and model.stage >= 2:
#         loss_audio = (model.loss(out_audio, y_audio)[audio_mask.squeeze() != 0]).mean()
#         loss_imgs = (model.loss(out_imgs, y_imgs)[imgs_mask.squeeze() != 0]).mean()
#         loss_text = (model.loss(out_text, y_int)[text_mask.squeeze() != 0]).mean()

    if model.stage == 1:
        if model.hparams.setting == 'aux':
            out_audio, out_imgs, out_text = results
#             loss_spk = model.hparams.alpha_spk * model.loss(out_spk, spk).mean()
#             loss_gen = model.hparams.alpha_gen * model.loss(out_gen, gen).mean()
#             loss_age = model.hparams.alpha_age * model.reg_loss(out_age, age).mean()
            loss_audio = model.loss(out_audio, y_int).mean()
            loss_imgs = model.loss(out_imgs, y_int).mean()
            loss_text = model.loss(out_text, y_int).mean()
            loss = (loss_audio + loss_imgs + loss_text)/3
#             loss += loss_spk + loss_gen + loss_age
        elif model.hparams.setting == 'full':
            out_int = results
            loss_int = model.loss(out_int, y_int).mean()
#             loss_spk = model.hparams.alpha_spk * model.loss(out_spk, spk).mean()
#             loss_gen = model.hparams.alpha_gen * model.loss(out_gen, gen).mean()
#             loss_age = model.hparams.alpha_age * model.reg_loss(out_age, age).mean()
            loss = loss_int
        
        optimizers['audio'].zero_grad()
        optimizers['img'].zero_grad()
        optimizers['text'].zero_grad()
#         optimizers['spk'].zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), model.hparams.clip)
        if not model.earlystop_audio:
            optimizers['audio'].step()
            if model.hparams.setting == 'aux':
                audio_acc, audio_f1, audio_wacc = acc_fn(out_audio, y_int)
                writer.add_scalar('train/stage1/loss/audio_loss', loss_audio, model.iter)
                writer.add_scalar('train/stage1/acc/audio_acc',audio_acc, model.iter)
        if not model.earlystop_img:
            optimizers['img'].step()
            if model.hparams.setting == 'aux':
                img_acc, img_f1, img_wacc = acc_fn(out_imgs, y_int)
                writer.add_scalar('train/stage1/loss/img_loss', loss_imgs, model.iter)
                writer.add_scalar('train/stage1/acc/img_acc',img_acc, model.iter)
        if not model.earlystop_text:
            optimizers['text'].step()
            if model.hparams.setting == 'aux':
                text_acc, text_f1, text_wacc  = acc_fn(out_text, y_int)
                writer.add_scalar('train/stage1/loss/text_loss', loss_text, model.iter)
                writer.add_scalar('train/stage1/acc/text_acc',text_acc, model.iter)
#         if not model.earlystop_spk:
#             optimizers['spk'].step()
#             writer.add_scalar('train/loss/spk/spk', loss_spk, model.iter)
#             writer.add_scalar('train/loss/spk/gen', loss_gen, model.iter)
#             writer.add_scalar('train/loss/spk/age', loss_age, model.iter)

        if model.hparams.setting == 'full':
            int_acc, int_f1, int_wacc = acc_fn(out_int, y_int)
            writer.add_scalar('train/loss/int_loss', loss_int, model.iter)
            writer.add_scalar('train/acc/int_acc',int_acc, model.iter)
        
    elif model.stage == 2:
        if model.hparams.setting == 'aux':
            out_av, out_at, out_vt = results
#             loss_spk = model.hparams.alpha_spk * model.loss(out_spk, spk).mean()
#             loss_gen = model.hparams.alpha_gen * model.loss(out_gen, gen).mean()
#             loss_age = model.hparams.alpha_age * model.reg_loss(out_age, age).mean()
            loss_av = model.loss(out_av, y_int).mean()
            loss_at = model.loss(out_at, y_int).mean()
            loss_vt = model.loss(out_vt, y_int).mean()
            loss = (loss_av + loss_at + loss_vt)/3
#             loss += loss_spk + loss_gen + loss_age
        elif model.hparams.setting == 'full':
            out_int = results
#             loss_spk = model.hparams.alpha_spk * model.loss(out_spk, spk).mean()
#             loss_gen = model.hparams.alpha_gen * model.loss(out_gen, gen).mean()
#             loss_age = model.hparams.alpha_age * model.reg_loss(out_age, age).mean()
            loss_int = model.loss(out_int, y_int).mean()
            loss = loss_int

        optimizers['av'].zero_grad()
        optimizers['at'].zero_grad()
        optimizers['vt'].zero_grad()
#         optimizers['spk'].zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), model.hparams.clip)
        if not model.earlystop_av:
            optimizers['av'].step()
            if model.hparams.setting == 'aux':
                av_acc, av_f1, av_wacc = acc_fn(out_av, y_int)
                writer.add_scalar('train/stage2/loss/av_loss', loss_av, model.iter)
                writer.add_scalar('train/stage2/acc/av_acc', av_acc, model.iter)
        if not model.earlystop_at:
            optimizers['at'].step()
            if model.hparams.setting == 'aux':
                at_acc, at_f1, at_wacc = acc_fn(out_at, y_int)
                writer.add_scalar('train/stage2/loss/at_loss', loss_at, model.iter)
                writer.add_scalar('train/stage2/acc/at_acc', at_acc, model.iter)
        if not model.earlystop_vt:
            optimizers['vt'].step()
            if model.hparams.setting == 'aux':
                vt_acc, vt_f1, vt_wacc = acc_fn(out_vt, y_int)
                writer.add_scalar('train/stage2/loss/vt_loss', loss_vt, model.iter)
                writer.add_scalar('train/stage2/acc/vt_acc', vt_acc, model.iter)
        
#         if not model.earlystop_spk:
#             optimizers['spk'].step()        
#             writer.add_scalar('train/loss/spk/spk', loss_spk, model.iter)
#             writer.add_scalar('train/loss/spk/gen', loss_gen, model.iter)
#             writer.add_scalar('train/loss/spk/age', loss_age, model.iter)
        if model.hparams.setting == 'full':
            int_acc, int_f1, int_wacc = acc_fn(out_int, y_int)
            writer.add_scalar('train/loss/int_loss', loss_int, model.iter)
            writer.add_scalar('train/acc/int_acc',int_acc, model.iter)
        
        
    elif model.stage == 3:
        out_int = results
        
        loss_int = model.loss(out_int, y_int).mean()
#         loss_spk = model.hparams.alpha_spk * model.loss(out_spk, spk).mean()
#         loss_gen = model.hparams.alpha_gen * model.loss(out_gen, gen).mean()
#         loss_age = model.hparams.alpha_age * model.reg_loss(out_age, age).mean()
        
        loss = loss_int
        
        optimizers['int'].zero_grad()
#         optimizers['spk'].zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), model.hparams.clip)
        optimizers['int'].step()
#         if not model.earlystop_spk:
#             optimizers['spk'].step()
        int_acc, int_f1, int_wacc = acc_fn(out_int, y_int)
        
#         writer.add_scalar('train/loss/spk/spk', loss_spk, model.iter)
#         writer.add_scalar('train/loss/spk/gen', loss_gen, model.iter)
#         writer.add_scalar('train/loss/spk/age', loss_age, model.iter)
        
        if model.hparams.setting == 'aux':
            writer.add_scalar('train/stage3/loss/int_loss', loss_int, model.iter)
            writer.add_scalar('train/stage3/acc/int_acc',int_acc, model.iter)
        elif model.hparams.setting == 'full':
            writer.add_scalar('train/loss/int_loss', loss_int, model.iter)
            writer.add_scalar('train/acc/int_acc',int_acc, model.iter)
            writer.add_scalar('train/acc/int_f1',int_f1, model.iter)
            writer.add_scalar('train/acc/int_wacc',int_wacc, model.iter)

    model.iter += 1
    
def batchify(batch, device):
    audios = []
    max_len = max([x.shape[0] for x in batch[0]])
    for audio in batch[0]:
        audio = np.pad(audio, ((0,max_len - audio.shape[0]),(0,0)))
        audios.append(audio)
    audios = np.stack(audios)

    imgs = []
    max_len = max([x.shape[0] for x in batch[4]])
    for img in batch[4]:
        img = np.pad(img, ((0,max_len - img.shape[0]),(0,0)))
        imgs.append(img)
    imgs = np.stack(imgs)
    
    texts = []
    max_len = max([x.shape[0] for x in batch[2]])
    for text in batch[2]:
        text = np.pad(text, ((0,max_len - text.shape[0]),(0,0)))
        texts.append(text)
    texts = np.stack(texts)

    audios = torch.tensor(audios).to(device)
    audio_lens = torch.tensor(batch[1]).to(device)
    
    imgs = torch.tensor(imgs).to(device)
    img_lens = torch.tensor(batch[5]).to(device)
    
    texts = torch.tensor(texts).to(device)
    text_lens = torch.tensor(batch[3]).to(device)
    
    bs = audios.shape[0]    
    return audios, audio_lens, texts, text_lens, imgs, img_lens

In [6]:
class HParams():
    img_dim = 512
    audio_dim = 256
    text_dim = 768
    
    audio_h = 64
    img_h = 4
    text_h = 32
    audio_bidirectional = True
    img_bidirectional = True
    text_bidirectional = True
    
    in_drop_img = 0.5
    
    fc1_av = 128
    fc1_at = 128
    fc1_vt = 128
    
    fc2_dim = 512
    fc_dropout = 0.5
    img_layers = 1
    audio_layers = 1
    text_layers = 1
    img_rnn_dropout = 0
    text_rnn_dropout = 0
    post_rnn_dropout = 0.5
    activation = 'gelu2'
    decay_step = 400
    decay_mag = 0.1
    optim_type = 'nadam'
    decay_type = 'none'
    max_lr = 0.05
    onecycle_epochs = 100
    final_div_factor = 1e5
    clip = 10
    
    spk_dim = 10
    num_spks = 100
    
    lr = 1e-3
    reg = 1e-6
    
    audio_lr = lr
    audio_reg = reg
    audio_optim = optim_type
    
    img_lr = lr
    img_reg = reg
    img_optim = optim_type
    
    text_lr = lr
    text_reg = reg
    text_optim = optim_type
    
    spk_lr = lr
    spk_reg = reg
    spk_optim = optim_type
    
    av_lr = lr
    av_reg = reg
    av_optim = optim_type
    
    at_lr = lr
    at_reg = reg
    at_optim = optim_type
    
    vt_lr = lr
    vt_reg = reg
    vt_optim = optim_type
    
    int_lr = lr
    int_reg = reg
    int_optim = optim_type
    
    init_stage = 3
    tune_prev = True
    
    alpha_spk = 2
    alpha_gen = 1
    alpha_age = 0.1
    
    masking = False
    mask_rate = 0.1
    
    setting = 'full'
    
    audio_segment = 20
    audio_step = 20
    img_segment = 6
    img_step = 200
    text_segment = 30
    img_step = 10
    
    update_every = 75
    validate_every = 10
    
    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            setattr(self, k, v)
            
def train(**kwargs):
    def evaluate():
        val_iter = iter(val_dataloader)
        model.eval()
        val_batch = len(val_iter)
        # val_batch = 100
        labels, val_preds = [], []
        for i in range(val_batch):
            batch = next(val_iter, None)
            if batch is None:
                val_iter = iter(val_dataloader)
                batch = next(val_iter, None)
        #     ys_int = torch.tensor(batch[5]).to(device)
            labels.extend(batch[6])

            batch = batchify(batch[:6], device)
            max_bs = 256
            N = batch[0].shape[0]
            all_preds = []
            for i in range(0,N, max_bs):
                segment = [x[i:i+max_bs] for x in batch]
                with torch.no_grad():
                    preds = model(segment) 
                all_preds.append(preds)

            if isinstance(all_preds[0], tuple):
                preds = [torch.cat([x[i] for x in all_preds], 0) for i in range(len(all_preds[0]))]
            else:
                preds = torch.cat(all_preds, 0)

            if isinstance(preds, tuple) or isinstance(preds, list):
                preds = [x.mean(0).view(1,-1) for x in preds]
            else:
                preds = preds.mean(0).view(1,-1)

            preds[preds > 0.5] = 1
            preds[preds <= 0.5] = 0
            val_preds.extend(preds.cpu().numpy())
        labels = np.array(labels)
        val_preds = np.array(val_preds)

        from sklearn.metrics import f1_score, accuracy_score
        from sklearn.metrics import confusion_matrix
        from data import emotions
        from collections import Counter

        f1s, waccs = [], []
        for emo in range(labels.shape[1]):
            sublabels = labels[:,emo]
            subpreds = val_preds[:,emo]
            print('---',emotions[emo])
            f1 = f1_score(sublabels, subpreds, average='binary')
            acc = (sublabels == subpreds).mean()
            counts = Counter(sublabels)
            n, p = counts[0], counts[1]
            tn, fp, fn, tp = confusion_matrix(sublabels, subpreds).ravel()
            wacc = (tp * n/p + tn) / (2*n)
            print("F1:", f1, "\nA:", acc, "\nWA:", wacc)
            print(tp,fp,fn,tn)
            f1s.append(f1)
            waccs.append(wacc)
        print(mean(f1s), mean(waccs))
    
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True,
            collate_fn = collates.collate_fn, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True,
            collate_fn = collates.val_collate_fn, num_workers=0)
    hparams = HParams(epoch_size = len(train_dataloader), **kwargs)
    
    collates.audio_segment = hparams.audio_segment
    collates.audio_step = hparams.audio_step
    collates.img_segment = hparams.img_segment
    collates.img_step = hparams.img_step
    collates.text_segment = hparams.text_segment
    collates.text_step = hparams.img_step

    H = [hparams.audio_h, hparams.img_h, hparams.text_h, hparams.fc1_av, hparams.fc1_at, hparams.fc1_vt, hparams.fc2_dim]
    BI = ['bi' if hparams.audio_bidirectional else 'uni',
          'bi' if hparams.img_bidirectional else 'uni',
          'bi' if hparams.text_bidirectional else 'uni']
    lr = [hparams.audio_lr, hparams.img_lr, hparams.text_lr, hparams.int_lr]
    reg = [hparams.audio_reg, hparams.img_reg, hparams.text_reg, hparams.int_reg]
    cur_datetime = datetime.now().strftime("%m-%d-%H:%M:%S")
    name = "({} {} {}) lr=[{}]-reg=[{}]-H=[{}]-[{}]-img=[{} {}]-audio=[{} {}]-text=[{} {}]".\
        format(hparams.setting[0],
               hparams.tune_prev,
               hparams.init_stage,
               " ".join(map(str,lr)),
               " ".join(map(str,reg)),
               " ".join(map(str,H)), " ".join(BI),
               collates.img_segment, collates.img_step,
               collates.audio_segment, collates.audio_step,
               collates.text_segment, collates.text_step)

    if name:
        name = "%s-%s"%(cur_datetime,name)
    else:
        name = cur_datetime

    device = torch.device('cuda')

    ## Model init !!!
    model, writer = initiate_model(hparams, name)
    optimizers = initiate_optimizers(hparams, model)

    train_iter = iter(train_dataloader)
    data_counter = 0
    if model.hparams.setting == 'aux':
        earlystopper_audio = EarlyStopping()
        earlystopper_img = EarlyStopping()
        earlystopper_text = EarlyStopping()
        earlystopper_av = EarlyStopping()
        earlystopper_at = EarlyStopping()
        earlystopper_vt = EarlyStopping()
        earlystopper_int = EarlyStopping()
    elif model.hparams.setting == 'full':
        earlystopper_stage1 = EarlyStopping()
        earlystopper_stage2 = EarlyStopping()
        earlystopper_stage3 = EarlyStopping()
    earlystopper_spk = EarlyStopping(patience=5)

    print("[Starting]", model.name)
    while True:
        
        if model.iter % (hparams.update_every*hparams.epoch_size//1) == 0:
            train_dataset.update_data()

        model.train()
        batch = next(train_iter, None)
        if batch is None:
            train_iter = iter(train_dataloader)
            batch = next(train_iter, None)


    #     batch = list(batch)
    #     batch[0] += np.random.normal(0,0.1,batch[0].shape)
    #     batch[2] += np.random.normal(0,0.1,batch[2].shape)
    #     batch[3] += np.random.normal(0,0.1,batch[3].shape)


        batch = [torch.tensor(x).to(device) for x in batch]
        res = train_step(model, optimizers, batch, writer)
        if model.iter % (hparams.validate_every*hparams.epoch_size//1) == 0:
            val_iter = iter(val_dataloader)
            model.eval()
            val_batch = 128
            results = []
            batch_preds = []
            batch_ys = []
            for i in range(val_batch):
                batch = next(val_iter, None)
                if batch is None:
                    val_iter = iter(val_dataloader)
                    batch = next(val_iter, None)
                ys_int = torch.tensor(batch[6]).to(device)

                batch = batchify(batch[:6], device)
                max_bs = 256
                N = batch[0].shape[0]
                all_preds = []
                for i in range(0,N, max_bs):
                    segment = [x[i:i+max_bs] for x in batch]
                    with torch.no_grad():
                        preds = model(segment) 
                    all_preds.append(preds)

                if isinstance(all_preds[0], tuple):
                    preds = [torch.cat([x[i] for x in all_preds], 0) for i in range(len(all_preds[0]))]
                else:
                    preds = torch.cat(all_preds, 0)

                if isinstance(preds, tuple) or isinstance(preds, list):
                    preds = [x.mean(0).view(1,-1) for x in preds]
                else:
                    preds = preds.mean(0).view(1,-1)
                batch_preds.append(preds)
                batch_ys.append(ys_int)

            if isinstance(batch_preds[0], tuple) or isinstance(batch_preds[0], list):
                preds = [torch.cat(x,0) for x in zip(*batch_preds)]
            else:
                preds = torch.cat(batch_preds,0)
            ys_int = torch.cat(batch_ys,0)

            if model.hparams.setting == 'aux' and (model.stage == 1 or model.stage == 2):
                if model.stage == 1:
                    with torch.no_grad():
                        loss_audio = model.loss(preds[0], ys_int).mean().item()
                        loss_img = model.loss(preds[1], ys_int).mean().item()
                        loss_text = model.loss(preds[2], ys_int).mean().item()
    #                         loss_spk = model.hparams.alpha_spk * model.loss(preds[3], spk).mean()
    #                         loss_gen = model.hparams.alpha_gen * model.loss(preds[4], gen).mean()
    #                         loss_age = model.hparams.alpha_age * model.reg_loss(preds[5], age).mean()
                    audio_acc, audio_f1, audio_wacc = acc_fn(preds[0], ys_int)
                    img_acc, img_f1, img_wacc = acc_fn(preds[1], ys_int)
                    text_acc, test_f1, text_wacc = acc_fn(preds[2], ys_int)

                    results.append((loss_audio, loss_img, loss_text,
                                   audio_acc, img_acc, text_acc,
    #                                    loss_spk, loss_gen, loss_age
                                   ))
                elif model.stage == 2:
                    with torch.no_grad():
                        loss_av = model.loss(preds[0], ys_int).mean().item()
                        loss_at = model.loss(preds[1], ys_int).mean().item()
                        loss_vt = model.loss(preds[2], ys_int).mean().item()
    #                         loss_spk = model.hparams.alpha_spk * model.loss(preds[3], spk).mean()
    #                         loss_gen = model.hparams.alpha_gen * model.loss(preds[4], gen).mean()
    #                         loss_age = model.hparams.alpha_age * model.reg_loss(preds[5], age).mean()

                    av_acc, av_f1, av_wacc = acc_fn(preds[0], ys_int)
                    at_acc, at_f1, at_wacc = acc_fn(preds[1], ys_int)
                    vt_acc, vt_f1, av_wacc = acc_fn(preds[2], ys_int)

                    results.append((loss_av, loss_at, loss_vt,
                                   av_acc, at_acc, vt_acc,
    #                                    loss_spk, loss_gen, loss_age
                                   ))
            else:
                with torch.no_grad():
                    loss_int = model.loss(preds, ys_int).mean().item()
    #                     loss_spk = model.hparams.alpha_spk * model.loss(preds[1], spk).mean()
    #                     loss_gen = model.hparams.alpha_gen * model.loss(preds[2], gen).mean()
    #                     loss_age = model.hparams.alpha_age * model.reg_loss(preds[3], age).mean()

                int_acc, int_f1, int_wacc = acc_fn(preds, ys_int)
                results.append((loss_int, int_acc, int_f1, int_wacc
    #                                 loss_spk, loss_gen, loss_age
                               ))

            results = [mean([res[i] for res in results]) for i in range(len(results[0]))]

            if model.hparams.setting == 'aux':
                if model.stage == 1:
                    loss_audio, loss_img, loss_text, audio_acc, img_acc, text_acc = results

                    if not model.earlystop_audio and earlystopper_audio(loss_audio):
                        model.earlystop_audio = True
                        print("[%d] Stopped audio"%model.iter)
                    if not model.earlystop_img and earlystopper_img(loss_img):
                        model.earlystop_img = True
                        print("[%d] Stopped img"%model.iter)
                    if not model.earlystop_text and earlystopper_text(loss_text):
                        model.earlystop_text = True
                        print("[%d] Stopped text"%model.iter)
    #                 if earlystopper_spk(loss_spk):
    #                     model.earlystop_spk = True

                    if not model.earlystop_audio:
                        writer.add_scalar('val/stage1/loss/audio_loss', loss_audio, model.iter)
                        writer.add_scalar('val/stage1/acc/audio_acc',audio_acc, model.iter)
                    if not model.earlystop_img:
                        writer.add_scalar('val/stage1/loss/img_loss', loss_img, model.iter)
                        writer.add_scalar('val/stage1/acc/img_acc',img_acc, model.iter)
                    if not model.earlystop_text:
                        writer.add_scalar('val/stage1/loss/text_loss', loss_text, model.iter)
                        writer.add_scalar('val/stage1/acc/text_acc',text_acc, model.iter)

                    if model.earlystop_audio and model.earlystop_img and model.earlystop_text:
                        model.stage = 2
                        model.earlystop_spk = False
                        earlystopper_spk.counter = 0
                elif model.stage == 2:
                    loss_av, loss_at, loss_vt, av_acc, at_acc, vt_acc = results

                    if not model.earlystop_av and earlystopper_av(loss_av):
                        model.earlystop_av = True
                        print("[%d] Stopped av"%model.iter)
                    if not model.earlystop_at and earlystopper_at(loss_at):
                        model.earlystop_at = True
                        print("[%d] Stopped at"%model.iter)
                    if not model.earlystop_vt and earlystopper_vt(loss_vt):
                        model.earlystop_vt = True
                        print("[%d] Stopped vt"%model.iter)
    #                 if earlystopper_spk(loss_spk):
    #                     model.earlystop_spk = True

                    if not model.earlystop_av:
                        writer.add_scalar('val/stage2/loss/av_loss', loss_av, model.iter)
                        writer.add_scalar('val/stage2/acc/av_acc',av_acc, model.iter)
                    if not model.earlystop_at:
                        writer.add_scalar('val/stage2/loss/at_loss', loss_at, model.iter)
                        writer.add_scalar('val/stage2/acc/at_acc',at_acc, model.iter)
                    if not model.earlystop_vt:
                        writer.add_scalar('val/stage2/loss/vt_loss', loss_vt, model.iter)
                        writer.add_scalar('val/stage2/acc/vt_acc',vt_acc, model.iter)

                    if model.earlystop_av and model.earlystop_at and model.earlystop_vt:
                        model.stage = 3
                elif model.stage == 3:
                    loss_int, int_acc, int_f1, int_wacc = results

                    writer.add_scalar('val/stage3/loss/int_loss', loss_int, model.iter)
                    writer.add_scalar('val/stage3/acc/int_acc',int_acc, model.iter)
                    writer.add_scalar('val/stage3/acc/int_f1', int_f1, model.iter)
                    writer.add_scalar('val/stage3/acc/int_wacc',int_wacc, model.iter)


                    if earlystopper_int(loss_int):
                        print("[%d] Stopped int"%model.iter)
    #                     train_dataset.update_data()
    #                     model.reset_state()
    #                     earlystopper_audio = EarlyStopping()
    #                     earlystopper_img = EarlyStopping()
    #                     earlystopper_text = EarlyStopping()
    #                     earlystopper_av = EarlyStopping()
    #                     earlystopper_at = EarlyStopping()
    #                     earlystopper_vt = EarlyStopping()
    #                     earlystopper_int = EarlyStopping()
                        break

            elif model.hparams.setting == 'full':
                loss_int, int_acc, int_f1, int_wacc = results
                if model.stage == 1:
    #                 if earlystopper_spk(loss_spk):
    #                     model.earlystop_spk = True
                    if earlystopper_stage1(loss_int):
                        model.stage = 2
                        print("[%d] Finished stage 1"%model.iter)
                        model.earlystop_spk = False
                        earlystopper_spk.counter = 0
                elif model.stage == 2:
    #                 if earlystopper_spk(loss_spk):
    #                     model.earlystop_spk = True
                    if earlystopper_stage2(loss_int):
                        model.stage = 3
                        print("[%d] Finished stage 2"%model.iter)
                elif model.stage == 3:
    #                 if earlystopper_spk(loss_spk):
    #                     model.earlystop_spk = True
                    if earlystopper_stage3(loss_int):
                        print("[%d] Finished stage 3"%model.iter)
    #                     train_dataset.update_data()
    #                     model.reset_state()
    #                     earlystopper_audio = EarlyStopping()
    #                     earlystopper_img = EarlyStopping()
    #                     earlystopper_text = EarlyStopping()
    #                     earlystopper_av = EarlyStopping()
    #                     earlystopper_at = EarlyStopping()
    #                     earlystopper_vt = EarlyStopping()
    #                     earlystopper_int = EarlyStopping()
                        break
                writer.add_scalar('val/loss/int_loss', loss_int, model.iter)
                writer.add_scalar('val/acc/int_acc',int_acc, model.iter)
                writer.add_scalar('val/acc/int_f1', int_f1, model.iter)
                writer.add_scalar('val/acc/int_wacc',int_wacc, model.iter)

    #         if not model.earlystop_spk:
    #             writer.add_scalar('val/loss/spk/spk', loss_spk, model.iter)
    #             writer.add_scalar('val/loss/spk/gen', loss_gen, model.iter)
    #             writer.add_scalar('val/loss/spk/age', loss_age, model.iter)

    evaluate()

In [40]:
train(setting='full', init_stage = 3, tune_prev = True)
train(setting='full', init_stage = 1, tune_prev = False)

  0%|          | 0/500 [00:00<?, ?it/s]

[Starting] 12-06-00:17:00-(f True 3) lr=[0.001 0.001 0.001 0.001]-reg=[1e-06 1e-06 1e-06 1e-06]-H=[64 4 32 128 128 128 512]-[bi bi bi]-img=[6 10]-audio=[20 20]-text=[30 10]


100%|██████████| 500/500 [02:47<00:00,  2.99it/s]
100%|██████████| 500/500 [02:51<00:00,  2.91it/s]
100%|██████████| 500/500 [02:43<00:00,  3.07it/s]
100%|██████████| 500/500 [02:54<00:00,  2.86it/s]
100%|██████████| 500/500 [02:47<00:00,  2.98it/s]
100%|██████████| 500/500 [02:49<00:00,  2.95it/s]
100%|██████████| 500/500 [02:53<00:00,  2.88it/s]
100%|██████████| 500/500 [02:47<00:00,  2.99it/s]
100%|██████████| 500/500 [02:53<00:00,  2.89it/s]
100%|██████████| 500/500 [02:44<00:00,  3.04it/s]
100%|██████████| 500/500 [02:50<00:00,  2.94it/s]
100%|██████████| 500/500 [02:51<00:00,  2.91it/s]
100%|██████████| 500/500 [02:45<00:00,  3.02it/s]
100%|██████████| 500/500 [02:45<00:00,  3.02it/s]
100%|██████████| 500/500 [02:43<00:00,  3.07it/s]
100%|██████████| 500/500 [02:44<00:00,  3.03it/s]
100%|██████████| 500/500 [02:45<00:00,  3.03it/s]
100%|██████████| 500/500 [02:46<00:00,  3.01it/s]
100%|██████████| 500/500 [02:49<00:00,  2.95it/s]
100%|██████████| 500/500 [02:47<00:00,  2.98it/s]


[15680] Finished stage 3
--- hap
F1: 0.9081455805892548 
A: 0.8398791540785498 
WA: 0.6221664178091488
262 37 16 16
--- sad
F1: 0.6458923512747875 
A: 0.622356495468278 
WA: 0.6205798374698002
114 65 60 92
--- fea
F1: 0.40404040404040403 
A: 0.4652567975830816 
WA: 0.5227396761870111
60 145 32 94
--- sur
F1: 0.5691056910569106 
A: 0.6797583081570997 
WA: 0.681565202500214
70 74 32 155
--- ang
F1: 0.4467005076142132 
A: 0.6706948640483383 
WA: 0.6027149321266968
44 43 66 178
--- dis
F1: 0.5533980582524272 
A: 0.7220543806646526 
WA: 0.6778303917348256
57 48 44 182
0.5878804321379995 0.6212660763046162


  0%|          | 0/500 [00:00<?, ?it/s]

[Starting] 12-06-01:50:50-(f False 1) lr=[0.001 0.001 0.001 0.001]-reg=[1e-06 1e-06 1e-06 1e-06]-H=[64 4 32 128 128 128 512]-[bi bi bi]-img=[6 10]-audio=[20 20]-text=[30 10]


100%|██████████| 500/500 [02:50<00:00,  2.94it/s]
100%|██████████| 500/500 [02:34<00:00,  3.23it/s]
100%|██████████| 500/500 [02:52<00:00,  2.90it/s]
100%|██████████| 500/500 [02:48<00:00,  2.97it/s]
100%|██████████| 500/500 [02:43<00:00,  3.07it/s]
100%|██████████| 500/500 [02:47<00:00,  2.98it/s]
100%|██████████| 500/500 [02:39<00:00,  3.13it/s]
100%|██████████| 500/500 [02:43<00:00,  3.05it/s]
100%|██████████| 500/500 [02:37<00:00,  3.18it/s]
100%|██████████| 500/500 [02:43<00:00,  3.06it/s]
100%|██████████| 500/500 [02:49<00:00,  2.95it/s]
100%|██████████| 500/500 [02:46<00:00,  3.00it/s]
100%|██████████| 500/500 [02:50<00:00,  2.94it/s]
100%|██████████| 500/500 [02:41<00:00,  3.10it/s]
100%|██████████| 500/500 [02:50<00:00,  2.93it/s]
100%|██████████| 500/500 [02:42<00:00,  3.08it/s]
100%|██████████| 500/500 [02:51<00:00,  2.92it/s]
100%|██████████| 500/500 [02:37<00:00,  3.18it/s]
100%|██████████| 500/500 [02:34<00:00,  3.23it/s]
100%|██████████| 500/500 [02:40<00:00,  3.12it/s]


[22680] Finished stage 1


100%|██████████| 500/500 [02:39<00:00,  3.14it/s]
100%|██████████| 500/500 [02:48<00:00,  2.97it/s]
100%|██████████| 500/500 [02:44<00:00,  3.05it/s]
100%|██████████| 500/500 [02:39<00:00,  3.14it/s]
100%|██████████| 500/500 [02:42<00:00,  3.07it/s]
100%|██████████| 500/500 [02:50<00:00,  2.93it/s]
100%|██████████| 500/500 [02:43<00:00,  3.07it/s]
100%|██████████| 500/500 [02:35<00:00,  3.22it/s]
100%|██████████| 500/500 [02:48<00:00,  2.96it/s]
100%|██████████| 500/500 [02:37<00:00,  3.18it/s]
100%|██████████| 500/500 [02:40<00:00,  3.11it/s]


[29120] Finished stage 2


100%|██████████| 500/500 [02:39<00:00,  3.13it/s]
100%|██████████| 500/500 [02:49<00:00,  2.95it/s]
100%|██████████| 500/500 [02:46<00:00,  3.01it/s]
100%|██████████| 500/500 [02:45<00:00,  3.02it/s]
100%|██████████| 500/500 [02:43<00:00,  3.05it/s]
100%|██████████| 500/500 [02:41<00:00,  3.10it/s]
100%|██████████| 500/500 [02:43<00:00,  3.06it/s]
100%|██████████| 500/500 [02:52<00:00,  2.91it/s]
100%|██████████| 500/500 [02:43<00:00,  3.05it/s]
100%|██████████| 500/500 [02:46<00:00,  3.01it/s]
100%|██████████| 500/500 [02:37<00:00,  3.17it/s]
100%|██████████| 500/500 [02:40<00:00,  3.12it/s]
100%|██████████| 500/500 [02:39<00:00,  3.13it/s]
100%|██████████| 500/500 [02:41<00:00,  3.09it/s]
100%|██████████| 500/500 [02:43<00:00,  3.06it/s]
100%|██████████| 500/500 [02:53<00:00,  2.89it/s]
100%|██████████| 500/500 [02:43<00:00,  3.06it/s]
100%|██████████| 500/500 [02:46<00:00,  3.00it/s]
100%|██████████| 500/500 [02:44<00:00,  3.03it/s]
100%|██████████| 500/500 [02:42<00:00,  3.08it/s]


[57120] Finished stage 3
--- hap
F1: 0.9093904448105437 
A: 0.8338368580060423 
WA: 0.49640287769784175
276 53 2 0
--- sad
F1: 0.6773455377574371 
A: 0.5740181268882175 
WA: 0.5590453181052786
148 115 26 42
--- fea
F1: 0.41322314049586784 
A: 0.3564954682779456 
WA: 0.4975668546479898
75 196 17 43
--- sur
F1: 0.5467128027681661 
A: 0.6042296072507553 
WA: 0.6514470416987757
79 108 23 121
--- ang
F1: 0.38738738738738737 
A: 0.5891238670694864 
WA: 0.5393459481694776
43 69 67 152
--- dis
F1: 0.5546218487394958 
A: 0.6797583081570997 
WA: 0.6723848471803702
66 71 35 159
0.581446860326483 0.5693654812499557


  0%|          | 0/500 [00:00<?, ?it/s]

[Starting] 12-06-07:18:11-(a False 1) lr=[0.001 0.001 0.001 0.001]-reg=[1e-06 1e-06 1e-06 1e-06]-H=[64 4 32 128 128 128 512]-[bi bi bi]-img=[6 10]-audio=[20 20]-text=[30 10]


 48%|████▊     | 242/500 [01:14<01:19,  3.25it/s]


KeyboardInterrupt: 

In [46]:
train(setting='aux', init_stage = 1, tune_prev = False)
train(setting='aux', init_stage = 1, tune_prev = True)

  0%|          | 0/500 [00:00<?, ?it/s]

[Starting] 12-06-13:47:05-(a False 1) lr=[0.001 0.001 0.001 0.001]-reg=[1e-06 1e-06 1e-06 1e-06]-H=[64 4 32 128 128 128 512]-[bi bi bi]-img=[6 10]-audio=[20 20]-text=[30 10]


100%|██████████| 500/500 [02:47<00:00,  2.98it/s]
100%|██████████| 500/500 [02:46<00:00,  3.00it/s]
100%|██████████| 500/500 [02:50<00:00,  2.93it/s]
100%|██████████| 500/500 [02:47<00:00,  2.98it/s]


[2320] Stopped img


100%|██████████| 500/500 [02:45<00:00,  3.01it/s]
100%|██████████| 500/500 [02:42<00:00,  3.07it/s]
100%|██████████| 500/500 [02:35<00:00,  3.21it/s]
100%|██████████| 500/500 [02:48<00:00,  2.96it/s]


[4240] Stopped text
[4400] Stopped audio


100%|██████████| 500/500 [02:47<00:00,  2.99it/s]
100%|██████████| 500/500 [02:42<00:00,  3.08it/s]
100%|██████████| 500/500 [02:50<00:00,  2.94it/s]


[6160] Stopped av
[6160] Stopped vt
[6320] Stopped at


100%|██████████| 500/500 [02:53<00:00,  2.89it/s]
100%|██████████| 500/500 [02:42<00:00,  3.08it/s]
100%|██████████| 500/500 [02:47<00:00,  2.98it/s]
100%|██████████| 500/500 [02:53<00:00,  2.89it/s]
100%|██████████| 500/500 [02:39<00:00,  3.13it/s]


[9440] Stopped int
--- hap
F1: 0.8975265017667845 
A: 0.824773413897281 
WA: 0.6360798153929685
254 34 24 19
--- sad
F1: 0.6666666666666667 
A: 0.6163141993957704 
WA: 0.6101654586719378
127 80 47 77
--- fea
F1: 0.42585551330798477 
A: 0.5438066465256798 
WA: 0.5637620520283791
56 115 36 124
--- sur
F1: 0.5714285714285713 
A: 0.6737160120845922 
WA: 0.6826354996146929
72 78 30 151
--- ang
F1: 0.4585365853658537 
A: 0.6646525679758308 
WA: 0.6050390785684904
47 48 63 173
--- dis
F1: 0.5538461538461539 
A: 0.7371601208459214 
WA: 0.6803702109341369
54 40 47 190
0.595643332063669 0.629675352535101


  0%|          | 0/500 [00:00<?, ?it/s]

[Starting] 12-06-14:47:39-(a True 1) lr=[0.001 0.001 0.001 0.001]-reg=[1e-06 1e-06 1e-06 1e-06]-H=[64 4 32 128 128 128 512]-[bi bi bi]-img=[6 10]-audio=[20 20]-text=[30 10]


100%|██████████| 500/500 [02:43<00:00,  3.06it/s]
100%|██████████| 500/500 [02:52<00:00,  2.90it/s]
100%|██████████| 500/500 [02:41<00:00,  3.10it/s]
100%|██████████| 500/500 [02:47<00:00,  2.99it/s]
100%|██████████| 500/500 [02:41<00:00,  3.09it/s]
100%|██████████| 500/500 [02:37<00:00,  3.17it/s]
100%|██████████| 500/500 [02:52<00:00,  2.90it/s]


[4000] Stopped text


100%|██████████| 500/500 [02:40<00:00,  3.12it/s]
100%|██████████| 500/500 [02:45<00:00,  3.01it/s]
100%|██████████| 500/500 [02:46<00:00,  3.00it/s]
100%|██████████| 500/500 [02:41<00:00,  3.10it/s]
100%|██████████| 500/500 [02:46<00:00,  3.01it/s]
100%|██████████| 500/500 [02:49<00:00,  2.95it/s]


[7600] Stopped audio


100%|██████████| 500/500 [02:49<00:00,  2.95it/s]
100%|██████████| 500/500 [02:43<00:00,  3.06it/s]


[8480] Stopped img


100%|██████████| 500/500 [02:38<00:00,  3.16it/s]
100%|██████████| 500/500 [02:49<00:00,  2.95it/s]
100%|██████████| 500/500 [02:46<00:00,  2.99it/s]
100%|██████████| 500/500 [02:44<00:00,  3.03it/s]
100%|██████████| 500/500 [02:41<00:00,  3.09it/s]


[11760] Stopped av


100%|██████████| 500/500 [02:50<00:00,  2.93it/s]
100%|██████████| 500/500 [02:41<00:00,  3.10it/s]
100%|██████████| 500/500 [02:45<00:00,  3.02it/s]
100%|██████████| 500/500 [02:44<00:00,  3.04it/s]
100%|██████████| 500/500 [02:42<00:00,  3.08it/s]


[14960] Stopped at
[14960] Stopped vt


100%|██████████| 500/500 [02:44<00:00,  3.05it/s]
100%|██████████| 500/500 [02:43<00:00,  3.05it/s]
100%|██████████| 500/500 [02:43<00:00,  3.06it/s]
100%|██████████| 500/500 [02:47<00:00,  2.99it/s]
100%|██████████| 500/500 [02:51<00:00,  2.92it/s]
100%|██████████| 500/500 [02:40<00:00,  3.12it/s]
100%|██████████| 500/500 [02:45<00:00,  3.03it/s]
100%|██████████| 500/500 [02:41<00:00,  3.09it/s]


[19280] Stopped int
--- hap
F1: 0.8876811594202898 
A: 0.8126888217522659 
WA: 0.6670625763540111
245 29 33 24
--- sad
F1: 0.6297376093294461 
A: 0.6163141993957704 
WA: 0.61607731166264
108 61 66 96
--- fea
F1: 0.3835616438356165 
A: 0.4561933534743202 
WA: 0.5030925959614335
56 144 36 95
--- sur
F1: 0.4973544973544973 
A: 0.7129909365558912 
WA: 0.6430559123212604
47 40 55 189
--- ang
F1: 0.425531914893617 
A: 0.6737160120845922 
WA: 0.5958453311394488
40 38 70 183
--- dis
F1: 0.5319148936170213 
A: 0.7341389728096677 
WA: 0.6670899698665519
50 37 51 193
0.5592969530750814 0.6153706162175576


  0%|          | 0/500 [00:00<?, ?it/s]

[Starting] 12-06-16:52:09-(a False 1) lr=[0.001 0.001 0.001 0.001]-reg=[1e-06 1e-06 1e-06 1e-06]-H=[64 4 32 128 128 128 512]-[bi bi bi]-img=[16 1]-audio=[20 1]-text=[30 1]


100%|██████████| 500/500 [02:52<00:00,  2.90it/s]


RuntimeError: CUDA out of memory. Tried to allocate 3.32 GiB (GPU 0; 7.93 GiB total capacity; 2.45 GiB already allocated; 2.64 GiB free; 2.29 GiB cached)

In [8]:
train(setting='aux', init_stage = 3, tune_prev = True)

  0%|          | 0/500 [00:00<?, ?it/s]

[Starting] 12-06-19:53:56-(a True 3) lr=[0.001 0.001 0.001 0.001]-reg=[1e-06 1e-06 1e-06 1e-06]-H=[64 4 32 128 128 128 512]-[bi bi bi]-img=[6 10]-audio=[20 20]-text=[30 10]


100%|██████████| 500/500 [02:53<00:00,  2.88it/s]
100%|██████████| 500/500 [02:45<00:00,  3.02it/s]
100%|██████████| 500/500 [02:41<00:00,  3.09it/s]
100%|██████████| 500/500 [02:46<00:00,  3.01it/s]
100%|██████████| 500/500 [02:48<00:00,  2.97it/s]
100%|██████████| 500/500 [02:43<00:00,  3.06it/s]


[3280] Stopped int
--- hap
F1: 0.8700173310225303 
A: 0.7734138972809668 
WA: 0.49860866024161804
251 48 27 5
--- sad
F1: 0.562300319488818 
A: 0.5861027190332326 
WA: 0.5904531810527857
88 51 86 106
--- fea
F1: 0.36363636363636365 
A: 0.5770392749244713 
WA: 0.5332908859377843
40 88 52 151
--- sur
F1: 0.37575757575757573 
A: 0.6888217522658611 
WA: 0.5820917886805378
31 32 71 197
--- ang
F1: 0.2967741935483871 
A: 0.6706948640483383 
WA: 0.554771698889346
23 22 87 199
--- dis
F1: 0.4268292682926829 
A: 0.716012084592145 
WA: 0.612397761515282
35 28 66 202
0.4825525086243929 0.561935662719559


In [7]:
train(setting='aux', init_stage = 1, tune_prev = False,
     audio_segment=20, audio_step=1,
     img_segment=16, img_step=1,
     text_segment=30, text_step=1)

  0%|          | 0/500 [00:00<?, ?it/s]

[Starting] 12-06-18:35:26-(a False 1) lr=[0.001 0.001 0.001 0.001]-reg=[1e-06 1e-06 1e-06 1e-06]-H=[64 4 32 128 128 128 512]-[bi bi bi]-img=[16 1]-audio=[20 1]-text=[30 1]


100%|██████████| 500/500 [03:11<00:00,  2.61it/s]
100%|██████████| 500/500 [02:57<00:00,  2.82it/s]
100%|██████████| 500/500 [02:54<00:00,  2.87it/s]
100%|██████████| 500/500 [02:43<00:00,  3.06it/s]


[2320] Stopped audio
[2320] Stopped img


100%|██████████| 500/500 [02:52<00:00,  2.90it/s]
100%|██████████| 500/500 [02:43<00:00,  3.07it/s]
100%|██████████| 500/500 [03:02<00:00,  2.74it/s]
100%|██████████| 500/500 [02:48<00:00,  2.97it/s]
100%|██████████| 500/500 [02:49<00:00,  2.94it/s]


[4960] Stopped text


100%|██████████| 500/500 [02:56<00:00,  2.83it/s]
100%|██████████| 500/500 [02:57<00:00,  2.82it/s]
100%|██████████| 500/500 [02:45<00:00,  3.03it/s]
100%|██████████| 500/500 [02:43<00:00,  3.05it/s]
100%|██████████| 500/500 [02:44<00:00,  3.04it/s]
100%|██████████| 500/500 [02:55<00:00,  2.85it/s]
100%|██████████| 500/500 [02:46<00:00,  3.01it/s]


[9520] Stopped at


  0%|          | 0/500 [00:00<?, ?it/s]

[9600] Stopped av


100%|██████████| 500/500 [02:44<00:00,  3.05it/s]
100%|██████████| 500/500 [02:43<00:00,  3.05it/s]


[10400] Stopped vt


100%|██████████| 500/500 [02:54<00:00,  2.87it/s]
100%|██████████| 500/500 [02:42<00:00,  3.08it/s]
100%|██████████| 500/500 [02:49<00:00,  2.95it/s]
100%|██████████| 500/500 [02:39<00:00,  3.13it/s]
100%|██████████| 500/500 [02:47<00:00,  2.98it/s]
100%|██████████| 500/500 [02:47<00:00,  2.99it/s]


[14160] Stopped int
--- hap
F1: 0.8782287822878229 
A: 0.8006042296072508 
WA: 0.6827745350889101
238 26 40 27
--- sad
F1: 0.6855670103092784 
A: 0.6314199395770392 
WA: 0.6242221246064866
133 81 41 76
--- fea
F1: 0.4565217391304348 
A: 0.5468277945619335 
WA: 0.5892532290340186
63 121 29 118
--- sur
F1: 0.5502183406113538 
A: 0.6888217522658611 
WA: 0.6690855381453892
63 64 39 165
--- ang
F1: 0.4818181818181818 
A: 0.6555891238670695 
WA: 0.6119498148909914
53 57 57 164
--- dis
F1: 0.5844748858447489 
A: 0.7250755287009063 
WA: 0.6994403788204908
64 54 37 176
0.6061381566669701 0.6461209367643811


In [9]:
train(setting='aux', init_stage = 1, tune_prev = False,
     audio_segment=20, audio_step=1,
     img_segment=6, img_step=1,
     text_segment=30, text_step=1)

  0%|          | 0/500 [00:00<?, ?it/s]

[Starting] 12-06-20:46:43-(a False 1) lr=[0.001 0.001 0.001 0.001]-reg=[1e-06 1e-06 1e-06 1e-06]-H=[64 4 32 128 128 128 512]-[bi bi bi]-img=[6 1]-audio=[20 1]-text=[30 1]


100%|██████████| 500/500 [02:49<00:00,  2.96it/s]
100%|██████████| 500/500 [02:50<00:00,  2.93it/s]
100%|██████████| 500/500 [02:49<00:00,  2.95it/s]
100%|██████████| 500/500 [02:47<00:00,  2.99it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

[2400] Stopped img


100%|██████████| 500/500 [02:42<00:00,  3.07it/s]
100%|██████████| 500/500 [02:38<00:00,  3.15it/s]
100%|██████████| 500/500 [02:46<00:00,  3.00it/s]
100%|██████████| 500/500 [02:49<00:00,  2.95it/s]
100%|██████████| 500/500 [02:46<00:00,  3.01it/s]


[4880] Stopped audio


100%|██████████| 500/500 [02:47<00:00,  2.99it/s]
100%|██████████| 500/500 [02:39<00:00,  3.13it/s]


[6400] Stopped text


100%|██████████| 500/500 [02:45<00:00,  3.01it/s]
100%|██████████| 500/500 [02:50<00:00,  2.93it/s]
100%|██████████| 500/500 [02:52<00:00,  2.90it/s]
100%|██████████| 500/500 [02:46<00:00,  3.01it/s]


[8960] Stopped at


100%|██████████| 500/500 [02:44<00:00,  3.05it/s]
100%|██████████| 500/500 [02:44<00:00,  3.03it/s]


[10000] Stopped av


100%|██████████| 500/500 [02:44<00:00,  3.03it/s]
100%|██████████| 500/500 [02:45<00:00,  3.03it/s]


[11200] Stopped vt


100%|██████████| 500/500 [02:48<00:00,  2.97it/s]
100%|██████████| 500/500 [02:46<00:00,  3.01it/s]
100%|██████████| 500/500 [02:52<00:00,  2.91it/s]
100%|██████████| 500/500 [02:36<00:00,  3.20it/s]


[13680] Stopped int
--- hap
F1: 0.8884826325411335 
A: 0.8157099697885196 
WA: 0.6917673408443058
243 26 35 27
--- sad
F1: 0.6738544474393532 
A: 0.6344410876132931 
WA: 0.6298960392415257
125 72 49 85
--- fea
F1: 0.39416058394160586 
A: 0.4984894259818731 
WA: 0.5256958340913226
54 128 38 111
--- sur
F1: 0.5418719211822661 
A: 0.7190332326283988 
WA: 0.6691711619145475
55 46 47 183
--- ang
F1: 0.4153005464480874 
A: 0.676737160120846 
WA: 0.5935417523652818
38 35 72 186
--- dis
F1: 0.4971751412429379 
A: 0.7311178247734139 
WA: 0.6482565647869135
44 32 57 198
0.568474212132564 0.6263881155406495


In [25]:
train_dataset.update_data()

100%|██████████| 450/450 [02:35<00:00,  2.90it/s]


In [35]:
class Collates():
    audio_segment = 30
    audio_step = 1
    img_segment = 30
    text_segment = 30
    text_step = 1
    img_step = 1
    fixed_size = False
    
    def collate_fn(self, batch):
      audios, texts, imgs, ys_int = [list(x) for x in zip(*batch)]
    
#       audio_mean = np.array([audio.mean(0) for audio in audios])
#       img_mean = np.array([img.mean(0) for img in imgs])
    
      audio_len = []
      for i, audio in enumerate(audios):
#         start_id = np.random.randint(0,max(1,audio.shape[0] - self.audio_step*self.audio_segment + 1))
        start_id = np.random.randint(0,audio.shape[0])
        audios[i] = audio[start_id:start_id + self.audio_step*self.audio_segment:self.audio_step]
        
      for i, img in enumerate(imgs):
#         start_id = np.random.randint(0, max(img.shape[0] - self.img_step*self.img_segment + 1, 1))
        start_id = np.random.randint(0, img.shape[0])
        imgs[i] = img[start_id:start_id + self.img_step*self.img_segment:self.img_step]
        
      for i, text in enumerate(texts):
#         start_id = np.random.randint(0, max(img.shape[0] - self.img_step*self.img_segment + 1, 1))
        start_id = np.random.randint(0, text.shape[0])
        texts[i] = text[start_id:start_id + self.text_step*self.text_segment:self.text_step]

#       text_lens = [text.shape[0] for text in texts]
#       max_text_len = max(text_lens)
#       for i,text in enumerate(texts):
#         n = text_lens[i]
#         texts[i] = np.pad(text, ((0,max_text_len-n),(0,0)), mode='wrap')

      img_lens = [img.shape[0] for img in imgs]
      if self.fixed_size:
        max_img_len = self.img_segment
      else:
        max_img_len = max(img_lens)
      for i,img in enumerate(imgs):
        n = img_lens[i]
        imgs[i] = np.pad(img, ((0,max_img_len-n),(0,0)), mode='wrap')

      audio_lens = [audio.shape[0] for audio in audios]
      if self.fixed_size:
        max_audio_len = self.audio_segment
      else:
        max_audio_len = max(audio_lens)
      for i,audio in enumerate(audios):
        n = audio_lens[i]
        audios[i] = np.pad(audio, ((0,max_audio_len-n),(0,0)), mode='wrap')
        
      text_lens = [text.shape[0] for text in texts]
      if self.fixed_size:
        max_text_len = self.text_segment
      else:
        max_text_len = max(text_lens)
      for i,text in enumerate(texts):
        n = text_lens[i]
        texts[i] = np.pad(text, ((0,max_text_len-n),(0,0)), mode='wrap')

        
#       audio_mean = np.repeat(np.expand_dims(audio_mean, 1), max_audio_len, 1)
#       img_mean = np.repeat(np.expand_dims(img_mean, 1), max_img_len, 1)
        
      audios = np.stack(audios)
#       audios = np.concatenate([audios, audio_mean], -1)
      imgs = np.stack(imgs)
#       imgs = np.concatenate([imgs, img_mean], -1)
      texts = np.stack(texts)

        
      res = audios, audio_lens, \
        texts, text_lens, \
        imgs, img_lens, np.array(ys_int).astype('float32')
      return res

    def val_collate_fn(self, batch):
      entry = batch[0]
      audio = entry[0]
      text = entry[1]
      img = entry[2]
      if len(entry) == 4:
        ys_int = np.array([entry[3]]).astype('float32')
      else:
        ys_int = None

      audios, imgs, texts = [], [], []
      img_lens, audio_lens, text_lens = [], [], []
      audio_n = audio.shape[0]
      img_n = img.shape[0]
      text_n = text.shape[0]
        
#       audio_mean = audio.mean(0)
#       img_mean = img.mean(0)

      if self.audio_step > 1:
        audio_off_step = self.audio_step
      else:
        audio_off_step = audio_n//10
      if self.img_step > 1:
        img_off_step = self.img_step
      else:
        img_off_step = img_n//10
      if self.text_step > 1:
        text_off_step = self.text_step
      else:
        text_off_step = text_n//10

      i = 0
      while True:
        audio_offset = i*audio_off_step
        img_offset = i*img_off_step
        text_offset = i*text_off_step

        
        if audio_offset >= audio_n and img_offset >= img_n and text_offset >= text_n:
            break

        if audio_offset < audio_n:
          audio_segment = audio[audio_offset:audio_offset + self.audio_step*self.audio_segment:self.audio_step, :]
        else:
#           start_id = np.random.randint(0, max(1,audio_n - self.audio_step*self.audio_segment + 1))
          start_id = np.random.randint(0, audio_n)
          audio_segment = audio[start_id:start_id+self.audio_step*self.audio_segment:self.audio_step, :]

        if img_offset < img_n:
            img_segment = img[img_offset:img_offset + self.img_step*self.img_segment:self.img_step, :]
        else:
#           start_id = np.random.randint(0, max(img_n - self.img_step*self.img_segment + 1, 1))
          start_id = np.random.randint(0, img_n)
          img_segment = img[start_id:start_id+self.img_step*self.img_segment:self.img_step, :]
        
        if text_offset < text_n:
            text_segment = text[text_offset:text_offset + self.text_step*self.text_segment:self.text_step, :]
        else:
#           start_id = np.random.randint(0, max(text_n - self.text_step*self.text_segment + 1, 1))
          start_id = np.random.randint(0, text_n)
          text_segment = text[start_id:start_id+self.text_step*self.text_segment:self.text_step, :]
        
#         audio_segment = np.concatenate([audio, np.repeat(np.expand_dims(audio_mean, 0), audio.shape[0], 0)], -1)
#         img_segment = np.concatenate([img, np.repeat(np.expand_dims(img_mean, 0), img.shape[0], 0)], -1)
    
        img_lens.append(img_segment.shape[0])
        audio_lens.append(audio_segment.shape[0])
        text_lens.append(text_segment.shape[0])
        
        
        if self.fixed_size:
          max_img_len = self.img_segment
          n = img_lens[-1]
          img_segment = np.pad(img_segment, ((0,max_img_len-n),(0,0)), mode='wrap')

        if self.fixed_size:
          max_audio_len = self.audio_segment
          n = audio_lens[-1]
          audio_segment = np.pad(audio_segment, ((0,max_audio_len-n),(0,0)), mode='wrap')
        
        audios.append(audio_segment)
        imgs.append(img_segment)
        texts.append(text_segment)

        i += 1



      res = audios, audio_lens, texts, text_lens, imgs, img_lens, ys_int
      return res

10262 3069 559
1026 306 55
0 0 0
1026 306 55
2052 612 110
3078 918 165
4104 1224 220
5130 1530 275
6156 1836 330
7182 2142 385
8208 2448 440
9234 2754 495
10260 3060 550
11286 3366 605


11

--- hap
F1: 0.9108910891089109 
A: 0.8368580060422961 
WA: 0.5058368399619927
276 52 2 1
--- sad
F1: 0.1913875598086124 
A: 0.48942598187311176 
WA: 0.5097005637308734
20 15 154 142
--- fea
F1: 0.0 
A: 0.7220543806646526 
WA: 0.5
0 0 92 239
--- sur
F1: 0.4444444444444445 
A: 0.6978851963746223 
WA: 0.6131089990581385
40 38 62 191
--- ang
F1: 0.23188405797101444 
A: 0.6797583081570997 
WA: 0.5455779514603044
16 12 94 209
--- dis
F1: 0.5662100456621004 
A: 0.7129909365558912 
WA: 0.6851915626345243
62 56 39 174
0.3908028661658471 0.5599026528076388


In [None]:
w = [1686, 759, 336, 176, 497, 332]
w2 = [max(1400/x, 1) for x in w]
# w2 = [max(w)/x for x in w]
w3 = w2 / np.linalg.norm(w2, 1)
[round(x,3) for x in w2]
# w2

In [None]:
class Collates():
    audio_segment = 30
    audio_step = 15
    img_segment = 30
    img_step = 15
    fixed_size = False
    
    def collate_fn(self, batch):
      audios, texts, imgs, ys_int = [list(x) for x in zip(*batch)]
    
#       audio_mean = np.array([audio.mean(0) for audio in audios])
#       img_mean = np.array([img.mean(0) for img in imgs])
    
      audio_len = []
      for i, audio in enumerate(audios):
        start_id = np.random.randint(0,max(1,audio.shape[0] - self.audio_step*self.audio_segment + 1))
        start_id = np.random.randint(0,audio.shape[0])
        audios[i] = audio[start_id:start_id + self.audio_step*self.audio_segment:self.audio_step]
        
      for i, img in enumerate(imgs):
#         start_id = np.random.randint(0, max(img.shape[0] - self.img_step*self.img_segment + 1, 1))
        start_id = np.random.randint(0, img.shape[0])
        imgs[i] = img[start_id:start_id + self.img_step*self.img_segment:self.img_step]

#       text_lens = [text.shape[0] for text in texts]
#       max_text_len = max(text_lens)
#       for i,text in enumerate(texts):
#         n = text_lens[i]
#         texts[i] = np.pad(text, ((0,max_text_len-n),(0,0)), mode='wrap')

      text_lens = None

      img_lens = [img.shape[0] for img in imgs]
      if self.fixed_size:
        max_img_len = self.img_segment
      else:
        max_img_len = max(img_lens)
      for i,img in enumerate(imgs):
        n = img_lens[i]
        imgs[i] = np.pad(img, ((0,max_img_len-n),(0,0)), mode='wrap')

      audio_lens = [audio.shape[0] for audio in audios]
      if self.fixed_size:
        max_audio_len = self.audio_segment
      else:
        max_audio_len = max(audio_lens)
      for i,audio in enumerate(audios):
        n = audio_lens[i]
        audios[i] = np.pad(audio, ((0,max_audio_len-n),(0,0)), mode='wrap')

        
#       audio_mean = np.repeat(np.expand_dims(audio_mean, 1), max_audio_len, 1)
#       img_mean = np.repeat(np.expand_dims(img_mean, 1), max_img_len, 1)
        
      audios = np.stack(audios)
#       audios = np.concatenate([audios, audio_mean], -1)
      imgs = np.stack(imgs)
#       imgs = np.concatenate([imgs, img_mean], -1)

        
      res = audios, audio_lens, \
        np.concatenate(texts), \
        imgs, img_lens, np.array(ys_int).astype('float32')
      return res

    def val_collate_fn(self, batch):
      entry = batch[0]
      audio = entry[0]
      text = entry[1]
      img = entry[2]
      ys_int = np.array([entry[3]]).astype('float32')

      audios, imgs = [], []
      img_lens, audio_lens = [], []
      audio_n = audio.shape[0]
      img_n = img.shape[0]
        
#       audio_mean = audio.mean(0)
#       img_mean = img.mean(0)


      i = 0
      while True:
        audio_offset = i*self.audio_step*100
        img_offset = i*self.img_step*100

        if audio_offset >= audio_n and img_offset >= img_n:
            break

        if audio_offset < audio_n:
          audio_segment = audio[audio_offset:audio_offset + self.audio_step*self.audio_segment:self.audio_step, :]
        else:
#           start_id = np.random.randint(0, max(1,audio_n - self.audio_step*self.audio_segment + 1))
          start_id = np.random.randint(0, audio_n)
          audio_segment = audio[start_id:start_id+self.audio_step*self.audio_segment:self.audio_step, :]

        if img_offset < img_n:
            img_segment = img[img_offset:img_offset + self.img_step*self.img_segment:self.img_step, :]
        else:
#           start_id = np.random.randint(0, max(img_n - self.img_step*self.img_segment + 1, 1))
          start_id = np.random.randint(0, img_n)
          img_segment = img[start_id:start_id+self.img_step*self.img_segment:self.img_step, :]
        
#         audio_segment = np.concatenate([audio, np.repeat(np.expand_dims(audio_mean, 0), audio.shape[0], 0)], -1)
#         img_segment = np.concatenate([img, np.repeat(np.expand_dims(img_mean, 0), img.shape[0], 0)], -1)
    
        img_lens.append(img_segment.shape[0])
        audio_lens.append(audio_segment.shape[0])
        
        
        if self.fixed_size:
          max_img_len = self.img_segment
          n = img_lens[-1]
          img_segment = np.pad(img_segment, ((0,max_img_len-n),(0,0)), mode='wrap')

        if self.fixed_size:
          max_audio_len = self.audio_segment
          n = audio_lens[-1]
          audio_segment = np.pad(audio_segment, ((0,max_audio_len-n),(0,0)), mode='wrap')
        
        audios.append(audio_segment)
        imgs.append(img_segment)

        i += 1


#       text_lens = [text.shape[1]]
      text_lens = None

      res = audios, audio_lens, text, imgs, img_lens, ys_int
      return res

    def test_collate_fn(self, batch):
      entry = batch[0]
      audio = entry[0]
      text = np.expand_dims(entry[1],0)
      img = entry[2]

      audios, imgs = [], []
      img_lens, audio_lens = [], []
      audio_n = audio.shape[0]
      img_n = img.shape[0]

      i = 0
      while True:
        audio_offset = i*self.audio_step//3
        img_offset = i*self.img_step//3

        if audio_offset >= audio_n and img_offset >= img_n:
            break

        if audio_offset < audio_n:
          audio_segment = audio[audio_offset:audio_offset + self.audio_step*self.audio_segment:self.audio_step, :]
        else:
#           start_id = np.random.randint(0, max(1,audio_n - self.audio_step*self.audio_segment + 1))
          start_id = np.random.randint(0, audio_n)
          audio_segment = audio[start_id:start_id+self.audio_step*self.audio_segment:self.audio_step, :]

        if img_offset < img_n:
            img_segment = img[img_offset:img_offset + self.img_step*self.img_segment:self.img_step, :]
        else:
#           start_id = np.random.randint(0, max(img_n - self.img_step*self.img_segment + 1, 1))
          start_id = np.random.randint(0, img_n)
          img_segment = img[start_id:start_id+self.img_step*self.img_segment:self.img_step, :]
        
#         audio_segment = np.concatenate([audio, np.repeat(np.expand_dims(audio_mean, 0), audio.shape[0], 0)], -1)
#         img_segment = np.concatenate([img, np.repeat(np.expand_dims(img_mean, 0), img.shape[0], 0)], -1)
    
        img_lens.append(img_segment.shape[0])
        audio_lens.append(audio_segment.shape[0])
        
        
        if self.fixed_size:
          max_img_len = self.img_segment
          n = img_lens[-1]
          img_segment = np.pad(img_segment, ((0,max_img_len-n),(0,0)), mode='wrap')

        if self.fixed_size:
          max_audio_len = self.audio_segment
          n = audio_lens[-1]
          audio_segment = np.pad(audio_segment, ((0,max_audio_len-n),(0,0)), mode='wrap')
        
        audios.append(audio_segment)
        imgs.append(img_segment)

        i += 1


      text_lens = [text.shape[1]]

      res = audios, audio_lens, text, text_lens, imgs, img_lens
      return res
collates = Collates()

In [None]:
class Collates():
    audio_segment = 30
    audio_step = 15
    img_segment = 30
    img_step = 15
    fixed_size = False
    
    def collate_fn(self, batch):
      audios, texts, imgs, ys_int = [list(x) for x in zip(*batch)]
    
#       audio_mean = np.array([audio.mean(0) for audio in audios])
#       img_mean = np.array([img.mean(0) for img in imgs])
    
      audio_len = []
      for i, audio in enumerate(audios):
        start_id = np.random.randint(0,max(1,audio.shape[0] - self.audio_step*self.audio_segment + 1))
        start_id = np.random.randint(0,audio.shape[0])
        audios[i] = audio[start_id:start_id + self.audio_step*self.audio_segment:self.audio_step]
        
      for i, img in enumerate(imgs):
#         start_id = np.random.randint(0, max(img.shape[0] - self.img_step*self.img_segment + 1, 1))
        start_id = np.random.randint(0, img.shape[0])
        imgs[i] = img[start_id:start_id + self.img_step*self.img_segment:self.img_step]

#       text_lens = [text.shape[0] for text in texts]
#       max_text_len = max(text_lens)
#       for i,text in enumerate(texts):
#         n = text_lens[i]
#         texts[i] = np.pad(text, ((0,max_text_len-n),(0,0)), mode='wrap')

      text_lens = None

      img_lens = [img.shape[0] for img in imgs]
      if self.fixed_size:
        max_img_len = self.img_segment
      else:
        max_img_len = max(img_lens)
      for i,img in enumerate(imgs):
        n = img_lens[i]
        imgs[i] = np.pad(img, ((0,max_img_len-n),(0,0)), mode='wrap')

      audio_lens = [audio.shape[0] for audio in audios]
      if self.fixed_size:
        max_audio_len = self.audio_segment
      else:
        max_audio_len = max(audio_lens)
      for i,audio in enumerate(audios):
        n = audio_lens[i]
        audios[i] = np.pad(audio, ((0,max_audio_len-n),(0,0)), mode='wrap')

        
#       audio_mean = np.repeat(np.expand_dims(audio_mean, 1), max_audio_len, 1)
#       img_mean = np.repeat(np.expand_dims(img_mean, 1), max_img_len, 1)
        
      audios = np.stack(audios)
#       audios = np.concatenate([audios, audio_mean], -1)
      imgs = np.stack(imgs)
#       imgs = np.concatenate([imgs, img_mean], -1)

        
      res = audios, audio_lens, \
        np.concatenate(texts), \
        imgs, img_lens, np.array(ys_int)
      return res

    def val_collate_fn(self, batch):
      entry = batch[0]
      audio = entry[0]
      text = entry[1]
      img = entry[2]
      ys_int = np.array([entry[3]])

      audios, imgs = [], []
      img_lens, audio_lens = [], []
      audio_n = audio.shape[0]
      img_n = img.shape[0]
        
#       audio_mean = audio.mean(0)
#       img_mean = img.mean(0)


      i = 0
      while True:
        audio_offset = i*self.audio_step
        img_offset = i*self.img_step

        if audio_offset >= audio_n and img_offset >= img_n:
            break

        if audio_offset < audio_n:
          audio_segment = audio[audio_offset:audio_offset + self.audio_step*self.audio_segment:self.audio_step, :]
        else:
#           start_id = np.random.randint(0, max(1,audio_n - self.audio_step*self.audio_segment + 1))
          start_id = np.random.randint(0, audio_n)
          audio_segment = audio[start_id:start_id+self.audio_step*self.audio_segment:self.audio_step, :]

        if img_offset < img_n:
            img_segment = img[img_offset:img_offset + self.img_step*self.img_segment:self.img_step, :]
        else:
#           start_id = np.random.randint(0, max(img_n - self.img_step*self.img_segment + 1, 1))
          start_id = np.random.randint(0, img_n)
          img_segment = img[start_id:start_id+self.img_step*self.img_segment:self.img_step, :]
        
#         audio_segment = np.concatenate([audio, np.repeat(np.expand_dims(audio_mean, 0), audio.shape[0], 0)], -1)
#         img_segment = np.concatenate([img, np.repeat(np.expand_dims(img_mean, 0), img.shape[0], 0)], -1)
    
        img_lens.append(img_segment.shape[0])
        audio_lens.append(audio_segment.shape[0])
        
        
        if self.fixed_size:
          max_img_len = self.img_segment
          n = img_lens[-1]
          img_segment = np.pad(img_segment, ((0,max_img_len-n),(0,0)), mode='wrap')

        if self.fixed_size:
          max_audio_len = self.audio_segment
          n = audio_lens[-1]
          audio_segment = np.pad(audio_segment, ((0,max_audio_len-n),(0,0)), mode='wrap')
        
        audios.append(audio_segment)
        imgs.append(img_segment)

        i += 1


#       text_lens = [text.shape[1]]
      text_lens = None

      res = audios, audio_lens, text, imgs, img_lens, ys_int
      return res

    def test_collate_fn(self, batch):
      entry = batch[0]
      audio = entry[0]
      text = np.expand_dims(entry[1],0)
      img = entry[2]

      audios, imgs = [], []
      img_lens, audio_lens = [], []
      audio_n = audio.shape[0]
      img_n = img.shape[0]

      i = 0
      while True:
        audio_offset = i*self.audio_step//3
        img_offset = i*self.img_step//3

        if audio_offset >= audio_n and img_offset >= img_n:
            break

        if audio_offset < audio_n:
          audio_segment = audio[audio_offset:audio_offset + self.audio_step*self.audio_segment:self.audio_step, :]
        else:
#           start_id = np.random.randint(0, max(1,audio_n - self.audio_step*self.audio_segment + 1))
          start_id = np.random.randint(0, audio_n)
          audio_segment = audio[start_id:start_id+self.audio_step*self.audio_segment:self.audio_step, :]

        if img_offset < img_n:
            img_segment = img[img_offset:img_offset + self.img_step*self.img_segment:self.img_step, :]
        else:
#           start_id = np.random.randint(0, max(img_n - self.img_step*self.img_segment + 1, 1))
          start_id = np.random.randint(0, img_n)
          img_segment = img[start_id:start_id+self.img_step*self.img_segment:self.img_step, :]
        
#         audio_segment = np.concatenate([audio, np.repeat(np.expand_dims(audio_mean, 0), audio.shape[0], 0)], -1)
#         img_segment = np.concatenate([img, np.repeat(np.expand_dims(img_mean, 0), img.shape[0], 0)], -1)
    
        img_lens.append(img_segment.shape[0])
        audio_lens.append(audio_segment.shape[0])
        
        
        if self.fixed_size:
          max_img_len = self.img_segment
          n = img_lens[-1]
          img_segment = np.pad(img_segment, ((0,max_img_len-n),(0,0)), mode='wrap')

        if self.fixed_size:
          max_audio_len = self.audio_segment
          n = audio_lens[-1]
          audio_segment = np.pad(audio_segment, ((0,max_audio_len-n),(0,0)), mode='wrap')
        
        audios.append(audio_segment)
        imgs.append(img_segment)

        i += 1


      text_lens = [text.shape[1]]

      res = audios, audio_lens, text, text_lens, imgs, img_lens
      return res
collates = Collates()

In [None]:
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True,
        collate_fn = collates.val_collate_fn, num_workers=0)
batch = batchify(next(iter(val_dataloader))[:6], device)
[x.shape for x in batch if x is not None]

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=3, shuffle=True,
        collate_fn = collates.collate_fn, num_workers=0)
batch = next(iter(train_dataloader))
batch[2].shape

In [None]:
test1_path = '/usr/cs/public/mohd/data/test1'
test1_txt = '/usr/cs/public/mohd/test1_data.txt'
test1_dataset = MyDataset(test1_path, test1_txt, mode='test', size=-1)

In [None]:
for set_id in range(1,4):
    test_path = '/usr/cs/public/mohd/data/test%d'%set_id
    test_txt = '/usr/cs/public/mohd/test%d_data.txt'%set_id
    test_dataset = MyDataset(test_path, test_txt, mode='test', size=-1)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False,
            collate_fn = collates.test_collate_fn, num_workers=0)
    test_iter = iter(test_dataloader)
    from data import emotions


    model.eval()

    out_file = open('test%d.csv'%set_id, 'w')
    out_file.write("FileID,Emotion\n")

    for i in range(len(test_iter)):
        batch = next(test_iter, None)
        fn = test_dataset.files[i][:-2]

        pred = model(batchify(batch[:6], device)) 
        pred = pred.mean(0).view(1,-1)
        pred = torch.argmax(pred, 1)
        pred = emotions[pred]
        out_file.write("%s,%s\n"%(fn, pred))

    out_file.close()

In [None]:
np.mean([x.shape[0] for x in train_dataset.audios])

In [None]:
import model, importlib
importlib.reload(model)
from model import MyModel2
from torch import nn, optim
import matplotlib.pyplot as plt

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
        collate_fn = collates.collate_fn, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True,
        collate_fn = collates.val_collate_fn, num_workers=0)
device = torch.device('cuda')
# device = torch.device('cpu')

model2 = MyModel2(2,2,200).to(device)
train_losses = []
val_losses = []
val_x = []

crit = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model2.parameters(), lr=1e-3, momentum=0.9, nesterov=True)
optimizer = optim.Adam(model2.parameters(), lr=1e-3)
# optimizer = optim.Adadelta(model2.parameters(), lr=1)
# decay = optim.lr_scheduler.ExponentialLR(optimizer, 0.995)
while True:
    model2.train()
    train_iter = iter(train_dataloader)
    batch = next(train_iter, None)
    if batch is None:
        train_iter = iter(train_dataloader)
        continue
    text = batch[2]
    text_len = batch[3]
    y = batch[8]
    y = torch.tensor(y).to(device)
    text = torch.tensor(text).to(device)
    text_len = torch.tensor(text_len).to(device)
    res = model2(text, text_len)
    loss = crit(res, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    acc = (torch.argmax(res,1) == y).cpu().numpy().mean()
    train_losses.append(loss.item())
    if model2.iter % 100 == 0:
        print('%7s %.2f %.2f'%('[Train]',loss.item(), acc))
        plt.plot(train_losses, label='train')
#         plt.title('train')   
    model2.iter += 1
#     decay.step()
    
    
    if model2.iter % 100 == 0:
        model2.eval()
        val_iter = iter(val_dataloader)
        losses, accs= [], []
        for i in range(128):
            batch = next(val_iter, None)
            text = batch[1]
            text = batch[2]
            text_len = batch[3]
            y = batch[8]
            y = torch.tensor(y).to(device)
            text = torch.tensor(text).to(device)
            text_len = torch.tensor(text_len).to(device)
            with torch.no_grad():
                res = model2(text, text_len)
                loss = crit(res, y)
            acc = (torch.argmax(res,1) == y).cpu().numpy().mean()
            losses.append(loss.item())
            accs.append(acc)
        val_losses.append(mean(losses))
        val_x.append(model2.iter)
        print('%7s %.2f %.2f'%('[Val]', mean(losses), mean(accs)))
        plt.plot(val_x,val_losses, label='val')
    if model2.iter % 100 == 0:
        plt.legend()
        plt.show()

In [None]:
import model, importlib
importlib.reload(model)
from model import MyModel2
from torch import nn, optim
import matplotlib.pyplot as plt

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
        collate_fn = collates.collate_fn, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True,
        collate_fn = collates.val_collate_fn, num_workers=0)
device = torch.device('cuda')
# device = torch.device('cpu')

model2 = MyModel2(2,2,400).to(device)
train_losses = []
val_losses = []
val_x = []

crit = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model2.parameters(), lr=1e-3, momentum=0.9, nesterov=True)
optimizer = optim.Adam(model2.parameters(), lr=1e-3)
# optimizer = optim.Adadelta(model2.parameters(), lr=1)
# decay = optim.lr_scheduler.ExponentialLR(optimizer, 0.995)
decay = optim.lr_scheduler.StepLR(optimizer, 1500)
print('starting')
while True:
    model2.train()
    train_iter = iter(train_dataloader)
    batch = next(train_iter, None)
    if batch is None:
        train_iter = iter(train_dataloader)
        continue
    text = batch[2]
    text_len = batch[3]
    y = batch[8]
    y = torch.tensor(y).to(device)
    text = torch.tensor(text).to(device)
    text_len = torch.tensor(text_len).to(device)
    res = model2(text, text_len)
    loss = crit(res, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    acc = (torch.argmax(res,1) == y).cpu().numpy().mean()
    train_losses.append(loss.item())
    if model2.iter % 100 == 0:
        print('%7s %.2f %.2f'%('[Train]',loss.item(), acc))
        plt.plot(train_losses, label='train')
#         plt.title('train')   
    model2.iter += 1
#     decay.step()
    
    
    if model2.iter % 100 == 0:
        model2.eval()
        val_iter = iter(val_dataloader)
        losses, accs= [], []
        for i in range(128):
            batch = next(val_iter, None)
            text = batch[1]
            text = batch[2]
            text_len = batch[3]
            y = batch[8]
            y = torch.tensor(y).to(device)
            text = torch.tensor(text).to(device)
            text_len = torch.tensor(text_len).to(device)
            with torch.no_grad():
                res = model2(text, text_len)
                loss = crit(res, y)
            acc = (torch.argmax(res,1) == y).cpu().numpy().mean()
            losses.append(loss.item())
            accs.append(acc)
        val_losses.append(mean(losses))
        val_x.append(model2.iter)
        print('%7s %.2f %.2f'%('[Val]', mean(losses), mean(accs)))
        plt.plot(val_x,val_losses, label='val')
    if model2.iter % 100 == 0:
        plt.legend()
        plt.show()

In [None]:
import model, importlib
importlib.reload(model)
from model import MyModel2
from torch import nn, optim
import matplotlib.pyplot as plt

train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True,
        collate_fn = collates.collate_fn, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True,
        collate_fn = collates.val_collate_fn, num_workers=0)
device = torch.device('cuda')
# device = torch.device('cpu')

# model2 = MyModel2(2,2,400).to(device)
train_losses = []
val_losses = []
val_x = []

crit = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model2.parameters(), lr=1e-3, momentum=0.9, nesterov=True)
# optimizer = optim.Adam(model2.parameters(), lr=1e-3)
# decay = optim.lr_scheduler.ExponentialLR(optimizer, 0.995)
decay = optim.lr_scheduler.StepLR(optimizer, 15000)
print('starting')
while True:
    model2.train()
    train_iter = iter(train_dataloader)
    batch = next(train_iter, None)
    if batch is None:
        train_iter = iter(train_dataloader)
        continue
    text = batch[2]
    text_len = batch[3]
    y = batch[8]
    y = torch.tensor(y).to(device)
    text = torch.tensor(text).to(device)
    text_len = torch.tensor(text_len).to(device)
    res = model2(text, text_len)
    loss = crit(res, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    acc = (torch.argmax(res,1) == y).cpu().numpy().mean()
    train_losses.append(loss.item())
    if model2.iter % 500 == 0:
        print('%7s %.2f %.2f'%('[Train]',loss.item(), acc))
        plt.plot(train_losses, label='train')
#         plt.title('train')   
    model2.iter += 1
    decay.step()
    
    
    if model2.iter % 500 == 0:
        model2.eval()
        val_iter = iter(val_dataloader)
        losses, accs= [], []
        for i in range(512):
            batch = next(val_iter, None)
            text = batch[1]
            text = batch[2]
            text_len = batch[3]
            y = batch[8]
            y = torch.tensor(y).to(device)
            text = torch.tensor(text).to(device)
            text_len = torch.tensor(text_len).to(device)
            with torch.no_grad():
                res = model2(text, text_len)
                loss = crit(res, y)
            acc = (torch.argmax(res,1) == y).cpu().numpy().mean()
            losses.append(loss.item())
            accs.append(acc)
        val_losses.append(mean(losses))
        val_x.append(model2.iter)
        print('%7s %.2f %.2f'%('[Val]', mean(losses), mean(accs)))
        plt.plot(val_x,val_losses, label='val')
    if model2.iter % 500 == 0:
        plt.legend()
        plt.show()

In [None]:
from torch import nn, optim
import torch.nn.functional as F

collates.audio_segment = 50
train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True,
        collate_fn = collates.collate_fn, num_workers=0)

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.iter = 0
#         self.rnn = nn.LSTM(80, 256, batch_first=True)
        self.cnn = nn.Sequential(
            nn.Conv1d(80,16,3),
            nn.ReLU(),
            nn.Conv1d(16,8,3),
            nn.ReLU()
        )
#         self.lin = nn.Sequential(
#             nn.Linear(256, 32),
#             nn.ReLU(),
#             nn.Linear(32,256),
#             nn.ReLU(),
#         )
        self.classify = nn.Linear(256, 7)
        
    def forward(self, x):
#         _, (h, _) = self.rnn(x)
#         h = torch.cat([x for x in h], -1)
#         h = self.lin(h)
        h = self.cnn(x)
        print(h.shape)
        return self.classify(h)
    
model = TestModel().to(device)
crit = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model2.parameters(), lr=1e-3, momentum=0.9, nesterov=True)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
while True:
    model.train()
    train_iter = iter(train_dataloader)
    batch = next(train_iter, None)
    if batch is None:
        train_iter = iter(train_dataloader)
        continue
    audio = batch[0]
    y = batch[7]
    y = torch.tensor(y).to(device)
    audio = torch.tensor(audio).to(device)
    res = model(audio)
    loss = crit(res, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    acc = (torch.argmax(res,1) == y).cpu().numpy().mean()
    if model.iter % 20 == 0:
        print('\r%7s %.2f %.2f'%('[Train]',loss.item(), acc), end='')
    model.iter += 1
    
#     if model2.iter % 50 == 0:
#         model2.eval()
#         val_iter = iter(val_dataloader)
#         losses, accs= [], []
#         for i in range(512):
#             batch = next(val_iter, None)
#             text = batch[1]
#             y = batch[7]
#             y = torch.tensor(y).to(device)
#             text = torch.tensor(text).to(device)
#             with torch.no_grad():
#                 res = model2(text)
#                 loss = crit(res, y)
#             acc = (torch.argmax(res,1) == y).cpu().numpy().mean()
#             losses.append(loss.item())
#             accs.append(acc)
#         print('\n%7s %.2f %.2f'%('[Val]', mean(losses), mean(accs)))

In [None]:
hs1 = [32,64,128,256,512]
hs2 = [32,64,128]
lrs = [1e-2,1e-3,1e-3,1e-4]
regs = [1e-4,1e-5,1e-6,1e-6]
segments = [5,10,15,20,25,30,35,40,45,50]
steps = [10,20,30,40,50]
BS = [64,128,256,512]

for i in range(1000):
    H = np.random.choice(hs1, 7)
    H2 = np.random.choice(hs2, 1)
    H[2] = H2[0]
    S = np.random.choice(segments, 2)
    ST = np.random.choice(steps, 2)
    bs = np.random.choice(BS)
#     reg = np.random.choice(regs, 4)
#     lr = np.random.choice(lrs, 4)
    
    
#     lr = [float(x) for x in lr]
#     reg = [float(x) for x in reg]
#     bs = int(bs)
    H = [int(x) for x in H]
    S = [int(x) for x in S]
    ST = [int(x) for x in ST]
    bs = int(bs)
#     BN = [int(x) for x in BN]
#     L = [int(x) for x in L]
    
        
    collates.audio_segment = S[0]
    collates.audio_step = ST[0]
    collates.img_segment = S[1]
    collates.img_step = ST[1]

    train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True,
            collate_fn = collates.collate_fn, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True,
            collate_fn = collates.val_collate_fn, num_workers=0)

    hparams = HParams(audio_h = H[0], img_h = H[1], text_h = H[2],
                      fc1_av = H[3], fc1_at = H[4], fc1_vt = H[5], fc2_dim = H[6],
                      epoch_size = len(train_dataloader))


    cur_datetime = datetime.now().strftime("%m-%d-%H:%M:%S")
    name = "H=[{}]-img=[{} {}]-audio=[{} {}]-bs=[{}]".format(" ".join(map(str,H)),
                                                     collates.img_segment, collates.img_step,
                                                     collates.audio_segment, collates.audio_step, bs)
    if name:
        name = "%s-%s"%(cur_datetime,name)
    else:
        name = cur_datetime

    ## Model init !!!
    model, writer = initiate_model(hparams, name)
    optimizers = initiate_optimizers(hparams, model)

    train_iter = iter(train_dataloader)
    data_counter = 0
    if model.hparams.setting == 'aux':
        earlystopper_audio = EarlyStopping()
        earlystopper_img = EarlyStopping()
        earlystopper_text = EarlyStopping()
        earlystopper_av = EarlyStopping()
        earlystopper_at = EarlyStopping()
        earlystopper_vt = EarlyStopping()
        earlystopper_int = EarlyStopping()
    elif model.hparams.setting == 'full':
        earlystopper_stage1 = EarlyStopping(patience=20)
        earlystopper_stage2 = EarlyStopping(patience=40)
        earlystopper_stage3 = EarlyStopping(patience=60)
    earlystopper_spk = EarlyStopping(patience=10)

    print("[Starting]", model.name)
    while True:
        model.train()
        batch = next(train_iter, None)
        if batch is None:
            train_iter = iter(train_dataloader)
            batch = next(train_iter, None)


    #     batch = list(batch)
    #     batch[0] += np.random.normal(0,0.3,batch[0].shape)
    #     batch[1] += np.random.normal(0,0.2,batch[1].shape)
    #     batch[3] += np.random.normal(0,0.3,batch[3].shape)


        batch = [torch.tensor(x).to(device) for x in batch]
        res = train_step(model, optimizers, batch, writer)
        if model.iter % (hparams.epoch_size//3) == 0:
            val_iter = iter(val_dataloader)
            model.eval()
            val_batch = 256
            results = []
            for i in range(val_batch):
                batch = next(val_iter, None)
                if batch is None:
                    val_iter = iter(val_dataloader)
                    batch = next(val_iter, None)
                ys_audio, ys_img, ys_int, spk, gen, age = [torch.tensor(x).to(device) for x in batch[6:]]

                preds = model(batchify(batch[:6], device)) 
                preds = [x.mean(0).view(1,-1) for x in preds]

                if model.hparams.setting == 'aux' and (model.stage == 1 or model.stage == 2):
                    if model.stage == 1:
                        with torch.no_grad():
                            loss_audio = model.loss(preds[0], ys_audio).item()
                            loss_img = model.loss(preds[1], ys_img).item()
                            loss_text = model.loss(preds[2], ys_int).item()
                            loss_spk = model.hparams.alpha_spk * model.loss(preds[3], spk).mean()
                            loss_gen = model.hparams.alpha_gen * model.loss(preds[4], gen).mean()
                            loss_age = model.hparams.alpha_age * model.reg_loss(preds[5], age).mean()
                        audio_acc = (torch.argmax(preds[0],1) == ys_audio).cpu().numpy().mean()
                        img_acc = (torch.argmax(preds[1],1) == ys_img).cpu().numpy().mean()
                        text_acc = (torch.argmax(preds[2],1) == ys_int).cpu().numpy().mean()

                        results.append((loss_audio, loss_img, loss_text,
                                       audio_acc, img_acc, text_acc,
                                       loss_spk, loss_gen, loss_age))
                    elif model.stage == 2:
                        with torch.no_grad():
                            loss_av = model.loss(preds[0], ys_int).item()
                            loss_at = model.loss(preds[1], ys_int).item()
                            loss_vt = model.loss(preds[2], ys_int).item()
                            loss_spk = model.hparams.alpha_spk * model.loss(preds[3], spk).mean()
                            loss_gen = model.hparams.alpha_gen * model.loss(preds[4], gen).mean()
                            loss_age = model.hparams.alpha_age * model.reg_loss(preds[5], age).mean()

                        av_acc = (torch.argmax(preds[0],1) == ys_audio).cpu().numpy().mean()
                        at_acc = (torch.argmax(preds[1],1) == ys_img).cpu().numpy().mean()
                        vt_acc = (torch.argmax(preds[2],1) == ys_int).cpu().numpy().mean()

                        results.append((loss_av, loss_at, loss_vt,
                                       av_acc, at_acc, vt_acc,
                                       loss_spk, loss_gen, loss_age))
                else:
                    with torch.no_grad():
                        loss_int = model.loss(preds[0], ys_int).item()
                        loss_spk = model.hparams.alpha_spk * model.loss(preds[1], spk).mean()
                        loss_gen = model.hparams.alpha_gen * model.loss(preds[2], gen).mean()
                        loss_age = model.hparams.alpha_age * model.reg_loss(preds[3], age).mean()

                    int_acc = (torch.argmax(preds[0],1) == ys_int).cpu().numpy().mean()
                    results.append((loss_int, int_acc, loss_spk, loss_gen, loss_age))

            results = [mean([res[i] for res in results]) for i in range(len(results[0]))]

            if model.hparams.setting == 'aux':
                if model.stage == 1:
                    loss_audio, loss_img, loss_text, audio_acc, img_acc, text_acc, loss_spk, loss_gen, loss_age = results

                    if not model.earlystop_audio and earlystopper_audio(loss_audio):
                        model.earlystop_audio = True
                        print("[%d] Stopped audio"%model.iter)
                    if not model.earlystop_img and earlystopper_img(loss_img):
                        model.earlystop_img = True
                        print("[%d] Stopped img"%model.iter)
                    if not model.earlystop_text and earlystopper_text(loss_text):
                        model.earlystop_text = True
                        print("[%d] Stopped text"%model.iter)
                    if earlystopper_spk(loss_spk):
                        model.earlystop_spk = True

                    if not model.earlystop_audio:
                        writer.add_scalar('val/stage1/loss/audio_loss', loss_audio, model.iter)
                        writer.add_scalar('val/stage1/acc/audio_acc',audio_acc, model.iter)
                    if not model.earlystop_img:
                        writer.add_scalar('val/stage1/loss/img_loss', loss_img, model.iter)
                        writer.add_scalar('val/stage1/acc/img_acc',img_acc, model.iter)
                    if not model.earlystop_text:
                        writer.add_scalar('val/stage1/loss/text_loss', loss_text, model.iter)
                        writer.add_scalar('val/stage1/acc/text_acc',text_acc, model.iter)

                    if model.earlystop_audio and model.earlystop_img and model.earlystop_text:
                        model.stage = 2
                        model.earlystop_spk = False
                        earlystopper_spk.counter = 0
                elif model.stage == 2:
                    loss_av, loss_at, loss_vt, av_acc, at_acc, vt_acc, loss_spk, loss_gen, loss_age = results

                    if not model.earlystop_av and earlystopper_av(loss_av):
                        model.earlystop_av = True
                        print("[%d] Stopped av"%model.iter)
                    if not model.earlystop_at and earlystopper_at(loss_at):
                        model.earlystop_at = True
                        print("[%d] Stopped at"%model.iter)
                    if not model.earlystop_vt and earlystopper_vt(loss_vt):
                        model.earlystop_vt = True
                        print("[%d] Stopped vt"%model.iter)
                    if earlystopper_spk(loss_spk):
                        model.earlystop_spk = True

                    if not model.earlystop_av:
                        writer.add_scalar('val/stage2/loss/av_loss', loss_av, model.iter)
                        writer.add_scalar('val/stage2/acc/av_acc',av_acc, model.iter)
                    if not model.earlystop_at:
                        writer.add_scalar('val/stage2/loss/at_loss', loss_at, model.iter)
                        writer.add_scalar('val/stage2/acc/at_acc',at_acc, model.iter)
                    if not model.earlystop_vt:
                        writer.add_scalar('val/stage2/loss/vt_loss', loss_vt, model.iter)
                        writer.add_scalar('val/stage2/acc/vt_acc',vt_acc, model.iter)

                    if model.earlystop_av and model.earlystop_at and model.earlystop_vt:
                        model.stage = 3
                elif model.stage == 3:
                    loss_int, int_acc, loss_spk, loss_gen, loss_age = results

                    writer.add_scalar('val/stage3/loss/int_loss', loss_int, model.iter)
                    writer.add_scalar('val/stage3/acc/int_acc',int_acc, model.iter)


                    if earlystopper_int(loss_int):
                        print("[%d] Stopped int"%model.iter)
                        break

            elif model.hparams.setting == 'full':
                loss_int, int_acc, loss_spk, loss_gen, loss_age = results
                if model.stage == 1:
                    if earlystopper_spk(loss_spk):
                        model.earlystop_spk = True
                    if earlystopper_stage1(loss_int):
                        model.stage = 2
                        print("[%d] Finished stage 1"%model.iter)
                        model.earlystop_spk = False
                        earlystopper_spk.counter = 0
                elif model.stage == 2:
                    if earlystopper_spk(loss_spk):
                        model.earlystop_spk = True
                    if earlystopper_stage2(loss_int):
                        model.stage = 3
                        print("[%d] Finished stage 2"%model.iter)
                elif model.stage == 3:
                    if earlystopper_spk(loss_spk):
                        model.earlystop_spk = True
                    if earlystopper_stage3(loss_int):
                        print("[%d] Finished stage 3"%model.iter)
                        break
                writer.add_scalar('val/loss/int_loss', loss_int, model.iter)
                writer.add_scalar('val/acc/int_acc',int_acc, model.iter)

            if not model.earlystop_spk:
                writer.add_scalar('val/loss/spk/spk', loss_spk, model.iter)
                writer.add_scalar('val/loss/spk/gen', loss_gen, model.iter)
                writer.add_scalar('val/loss/spk/age', loss_age, model.iter)

In [None]:
from sklearn.manifold import TSNE
from data import emotions
import matplotlib.pyplot as plt
tsne = TSNE(n_components=2, verbose=0, perplexity=30, n_iter=1000)

In [None]:
n = 100
for batch in train_dataloader:
    imgs = batch[3][:n]
    y = [emotions[i] for i in batch[-2]][:n]
    imgs = imgs.reshape(-1,2048)
    break

In [None]:
x = tsne.fit_transform(imgs)

In [None]:
x_plot = {emo: [] for emo in emotions}
offset = 0
for label in y:
    x_plot[label].append(x[offset:offset+16].mean(0))
    offset += 16
for label in y:
    x_plot[label] = np.array(x_plot[label])

In [None]:
x_plot

In [None]:
for label in emotions:
    plt.scatter(x_plot[label][:,0], x_plot[label][:,1], label=label)
plt.legend()
plt.show()

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob
from time import time

In [None]:
paths = glob('/usr/cs/public/data/train/*_imgs')
paths.sort()
path = paths[np.random.randint(9600)]
path = paths[9599]
print(path)
for i,fn in enumerate(os.listdir(path)):
    if i % 20 == 0:
        im = Image.open(os.path.join(path,fn))
        plt.imshow(im)
        plt.show()

In [None]:
im.size

In [None]:
paths = glob('/usr/cs/public/data/train/*_imgs')
paths.sort()
path = paths[np.random.randint(7500)]
t0 = time()
ims = []
for i,fn in enumerate(os.listdir(path)):
    im = Image.open(os.path.join(path,fn))
    ims.append(np.array(im))
imd = np.stack(ims)
print(time() - t0)

In [None]:
def collate_fn(batch):
  audios, texts, imgs, ys_audio, ys_img, ys_int = [list(x) for x in zip(*batch)]

  text_lens = [text.shape[0] for text in texts]
  max_text_len = max(text_lens)
  for i,text in enumerate(texts):
    n = text_lens[i]
    texts[i] = np.pad(text, ((0,max_text_len-n),(0,0)), mode='constant')
    
  img_lens = [img.shape[0] for img in imgs]
  max_img_len = max(img_lens)
  for i,img in enumerate(imgs):
    n = img_lens[i]
    imgs[i] = np.pad(img, ((0,max_img_len-n),(0,0)), mode='constant')


  res = np.stack(audios).transpose(0,2,1), np.stack(texts), text_lens, np.stack(imgs), img_lens, \
    np.array(ys_audio), np.array(ys_img), np.array(ys_int)
  return res

def val_collate_fn(batch):
  entry = batch[0]
  audio = np.expand_dims(entry[0],0).transpose(0,2,1)
  text = np.expand_dims(entry[1],0)
  img = np.expand_dims(entry[2],0)
  ys_audio = np.array([entry[3]])
  ys_img = np.array([entry[4]])
  ys_int = np.array([entry[5]])

  audios, imgs = [], []
  img_lens = []
  audio_n = audio.shape[1]
  img_n = img.shape[1]
  audio_segments = ceil(audio_n/AUDIO_SEGMENT)
  img_segments = ceil(img_n/IMG_SEGMENT)

  for i in range(max(audio_segments, img_segments)):
    audio_offset = i*AUDIO_SEGMENT
    img_offset = i*IMG_SEGMENT
    if audio_offset < audio_n:
      audios.append(audio[:, audio_offset:audio_offset + AUDIO_SEGMENT, :])
    else:
      start_id = np.random.randint(0, audio_n - AUDIO_SEGMENT + 1)
      audios.append(audio[:, start_id:start_id+AUDIO_SEGMENT, :])
    if img_offset < img_n:
        imgs.append(img[:, img_offset:img_offset + IMG_SEGMENT, :])
    else:
      start_id = np.random.randint(0, max(img_n - IMG_SEGMENT + 1, 1))
      imgs.append(img[:, start_id:start_id+IMG_SEGMENT, :])
    img_lens.append(imgs[-1].shape[1])
    
  text_lens = [text.shape[1]]

  res = audios, text, text_lens, imgs, img_lens, ys_audio, ys_img, ys_int
  return res

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
        collate_fn = collate_fn, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True,
        collate_fn = val_collate_fn, num_workers=0)

In [None]:
from sklearn.metrics import confusion_matrix
from data import emotions
import seaborn as sns
import matplotlib.pyplot as plt

val_iter = iter(val_dataloader)
model.eval()
val_batch = 10
predictions = [[], [], []]
trues = [[], [], []]
for i in range(val_batch):
    batch = next(val_iter, None)
    if batch is None:
        val_iter = iter(val_dataloader)
        batch = next(val_iter, None)
    audios, text, text_lens, imgs, img_lens, ys_audio, ys_img, ys_int = batch    
    trues[0].extend(ys_img.tolist())
    text, text_lens, ys_audio, ys_img, ys_int = [torch.tensor(x).to(device) \
                                                 for x in (text, text_lens, ys_audio, ys_img, ys_int)]
    preds = [[] for i in range(3)]
    print("True:", ys_img.item())
    subpreds = []
    confs = []
    for audio, img, img_len in zip(audios, imgs, img_lens):
        audio = torch.tensor(audio).to(device)
        img = torch.tensor(img).to(device)
        img_len = torch.tensor([img_len]).to(device)
        batch = audio, text, text_lens, img, img_len
        with torch.no_grad():
            pred = model(batch)
        x = torch.argmax(pred[1],1).item()
        subpreds.append(x)
        confs.append(pred[1][0][x].item())
        for i in range(len(preds)):
            preds[i].append(pred[i].cpu().numpy())
    print("Pred:", subpreds)
    print("Confs", ["%.2f"%x for x in confs])
    preds_avg = [combine_avg(ps) for ps in preds]
    preds = [torch.tensor(l).to(device) for l in preds_avg]
#     int_c = torch.argmax(preds[0],1)
    img_c = torch.argmax(preds[1],1)
    predictions[0].extend(img_c.cpu().numpy().tolist())    

print(emotions)
cm = confusion_matrix(trues[0], predictions[0], labels=list(range(7)))
print(cm)
sns.heatmap(cm,
            xticklabels = emotions,
            yticklabels = emotions)
plt.show()

In [None]:
import matplotlib.pyplot as plt

for batch in train_dataloader:
    for x in batch[0]:
        plt.imshow(x.T, origin='lower')
        plt.show()
    break
