In [1]:
# IPython autoreload extension in mode 2, so edits to imported project modules
# (e.g. the src.* files) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import multiprocessing
import pickle
import sys
import argparse
import os
import time

import numpy as np
from numpy import savetxt

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import CLEVR, collate_data, transform, GQA
from src.edl_utils import *
from src.model import MACNetwork
# from model_gqa import MACNetwork


In [3]:
# Mini-batch size for the training DataLoader (validation uses 4x this value).
batch_size = 128
# Step size handed to the Adam optimizer.
learning_rate = 1e-4
# Per-dataset dimension, looked up by dataset name and passed to MACNetwork.
dim_dict = dict(CLEVR=512, gqa=2048)

In [4]:
# Report what PyTorch can see: the number of CUDA devices, then one name per line.
gpu_count = torch.cuda.device_count()
print( gpu_count )
for gpu_index in range(gpu_count):
    print( torch.cuda.get_device_name(gpu_index) )

# Prefer the GPU; fall back to the CPU when CUDA is not available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# On a CUDA machine this prints "cuda", otherwise "cpu".
print(device)

1
Tesla K40m
cuda


In [5]:
# params updating using running average
def accumulate(model1, model2, decay=0.999):
    """Fold ``model2``'s parameters into ``model1`` as an exponential moving average.

    For every parameter name found in ``model1`` the update, applied in place, is::

        model1.p <- decay * model1.p + (1 - decay) * model2.p

    so ``decay=0`` turns ``model1``'s parameters into a copy of ``model2``'s.
    Only ``named_parameters()`` are visited; buffers are not touched.

    Args:
        model1: module whose parameters are overwritten with the running average.
        model2: module providing the fresh parameter values (left unchanged).
        decay: weight kept from ``model1``'s current values.

    Raises:
        KeyError: if ``model2`` lacks a parameter name that ``model1`` has.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    # The update is bookkeeping, not part of the autograd graph.
    with torch.no_grad():
        for name, target in par1.items():
            # Keyword `alpha` replaces the deprecated positional
            # add_(scalar, tensor) overload used previously.
            target.mul_(decay).add_(par2[name].detach(), alpha=1 - decay)


        
def train(epoch, dataset_type, dataset_object, uncertainty=True, num_classes=28):
    """Run one training epoch and refresh the running-average network after every step.

    Uses the notebook globals ``net``, ``net_running``, ``criterion``, ``optimizer``,
    ``device``, ``batch_size`` and ``decay``.

    Args:
        epoch: zero-based epoch index; displayed as ``epoch + 1`` and forwarded to
            the criterion on the uncertainty path.
        dataset_type: not used inside this function; kept so existing calls work.
        dataset_object: dataset wrapped in a ``DataLoader`` with ``collate_data``.
        uncertainty: when True (the default, and the previously hard-coded
            behaviour) the target is one-hot encoded and the criterion is called as
            ``criterion(outputs, y, epoch, num_classes, 10, device)``; when False it
            is called as ``criterion(outputs, answer)``.
        num_classes: width of the one-hot target (default 28, the value the
            notebook used for CLEVR). Presumably this must equal the network's
            output width -- confirm against ``n_answers``.

    Returns:
        None. Progress (loss, batch accuracy, mean batch accuracy so far) is shown
        on the tqdm bar.
    """
    train_set = DataLoader(
        dataset_object, batch_size=batch_size, num_workers=1, collate_fn=collate_data
    )

    annealing_step = 10  # forwarded to the uncertainty criterion
    running_acc = 0.0    # mean of the per-batch accuracies seen so far

    net.train(True)
    pbar = tqdm(train_set)
    for iter_id, (image, question, q_len, answer) in enumerate(pbar):
        image, question, answer = (
            image.to(device),
            question.to(device),
            answer.to(device),
        )

        net.zero_grad()

        outputs = net(image, question, q_len)
        if uncertainty:
            y = one_hot_embedding(answer, num_classes)
            y = y.to(device)
            loss = criterion(
                outputs, y.float(), epoch, num_classes, annealing_step, device)
        else:
            loss = criterion(outputs, answer)

        loss.backward()
        optimizer.step()

        # Accuracy over the samples actually present in this batch, so a short
        # final batch is not under-reported by dividing by the nominal batch size.
        batch_acc = (outputs.detach().argmax(1) == answer).float().mean().item()

        # Incremental mean keyed on the step index; unlike an `== 0` sentinel it
        # stays correct when the accuracy so far is exactly zero.
        running_acc += (batch_acc - running_acc) / (iter_id + 1)

        pbar.set_description('Epoch: {}; CurLoss: {:.8f}; CurAcc: {:.5f}; Tot_Acc: {:.5f}'.format(epoch + 1, loss.item(), batch_acc, running_acc))

        accumulate(net_running, net, decay)
    return


def valid(epoch, dataset_type, dataset_object, uncertainty=True, num_classes=28):
    """Evaluate the running-average network on ``dataset_object``.

    Uses the notebook globals ``net_running``, ``criterion``, ``device`` and
    ``batch_size`` (evaluation batches are ``4 * batch_size``).

    Args:
        epoch: zero-based epoch index; displayed as ``epoch + 1`` and forwarded to
            the criterion on the uncertainty path.
        dataset_type: not used inside this function; kept so existing calls work.
        dataset_object: dataset wrapped in a ``DataLoader`` with ``collate_data``.
        uncertainty: when True (the default, and the previous behaviour) the
            criterion is called as ``criterion(outputs, y, epoch, num_classes, 10,
            device)`` with a one-hot target; when False as
            ``criterion(outputs, answer)``.
        num_classes: width of the one-hot target (default 28).

    Returns:
        ``(val_acc, val_loss)``: the fraction of correctly answered samples and the
        per-sample average loss.

    Raises:
        ZeroDivisionError: if the dataset yields no samples.
    """
    valid_set = DataLoader(
        dataset_object, batch_size=4*batch_size, num_workers=1, collate_fn=collate_data
    )

    annealing_step = 10  # forwarded to the uncertainty criterion

    net_running.train(False)
    correct_counts = 0
    total_counts = 0
    running_loss = 0.0  # sum over batches of (batch loss * batch size)
    with torch.no_grad():
        pbar = tqdm(valid_set)
        for image, question, q_len, answer in pbar:
            image, question, answer = (
                image.to(device),
                question.to(device),
                answer.to(device),
            )

            outputs = net_running(image, question, q_len)
            if uncertainty:
                y = one_hot_embedding(answer, num_classes)
                y = y.to(device)
                loss = criterion(
                    outputs, y.float(), epoch, num_classes, annealing_step, device)
            else:
                loss = criterion(outputs, answer)

            correct = outputs.argmax(1) == answer
            n_samples = correct.size(0)

            # One reduction and one host sync per batch instead of a Python loop
            # over every element of the tensor.
            correct_counts += int(correct.sum().item())
            total_counts += n_samples

            # The criterion is treated as a per-batch mean (the logged training
            # loss is on a per-sample scale), so weight it by the batch size; the
            # division by total_counts below then gives a true per-sample average.
            running_loss += loss.item() * n_samples

            pbar.set_description('Epoch: {}; Loss: {:.5f}; Acc: {:.5f}'.format(epoch + 1, loss.item(), correct_counts / total_counts))

    val_acc = correct_counts / total_counts
    val_loss = running_loss / total_counts
    print('Validation Accuracy: {:.5f}'.format(val_acc))
    print('Validation Loss: {:.8f}'.format(val_loss))

    return val_acc, val_loss

In [6]:
# ---- Run configuration ----
dataset_type = 'CLEVR'   # selects the dictionary file, the dataset class and dim_dict entry
decay = 0.999            # running-average decay passed to accumulate()
load_embd = False        # True would mean pre-trained GLOVE embeddings (not implemented)
out_name = 'try'         # sub-directory of result/ that receives logs and checkpoints
n_epoch = 25

out_directory = 'result/'+ out_name +'/'
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(out_directory, exist_ok=True)
print('Saving result to: ', out_directory)

if not load_embd:
    # NOTE: pickle.load can execute arbitrary code; only open dictionary files
    # that were produced by this project's own preprocessing.
    with open(f'data/{dataset_type}_dic.pkl', 'rb') as f:
        dic = pickle.load(f)
    # +1 on the vocabulary size: presumably reserves an extra index (e.g. padding)
    # -- confirm against the preprocessing script.
    n_words = len(dic['word_dic']) + 1
    n_answers = len(dic['answer_dic'])
    print('Training word embeddings from scratch...')
else:
    # This branch never defined n_words / n_answers, so the model-construction
    # cell would fail later with a NameError. Fail here with a clear message.
    raise NotImplementedError(
        'load_embd=True is not supported yet: add the code that loads GLOVE '
        'embeddings and defines n_words and n_answers.'
    )

Saving result to:  result/try/
Training word embeddings from scratch...


In [7]:
# Build the train/val dataset objects and time it (original note: the data is
# read through hdf5, so constructing the objects adds minimal overhead).
since = time.time()
dataset_cls, data_root = (
    (CLEVR, 'data/CLEVR_v1.0') if dataset_type == "CLEVR" else (GQA, 'data/gqa')
)
# The training split gets `transform`; the 'val' split is built with no transform.
train_object = dataset_cls(data_root, transform=transform)
val_object = dataset_cls(data_root, 'val', transform=None)
elapsed_seconds = time.time() - since
print('Dataset loaded in %.2f seconds' % elapsed_seconds)


Dataset loaded in 1.97 seconds


In [8]:
# Two identically configured networks, built in the same order as before:
# `net` is the one being optimised, `net_running` holds the running average
# of its parameters and is the one evaluated and checkpointed.
net, net_running = (
    MACNetwork(n_words, dim_dict[dataset_type], classes=n_answers, max_step=4).to(device)
    for _ in range(2)
)
# decay=0 overwrites net_running's parameters with net's, so both start identical.
accumulate(net_running, net, 0)

In [9]:
# edl training: choose the loss function and build the optimizer.
use_uncertainty = True
loss_func = 'digamma'

# NOTE(review): train() and valid() call the criterion with the uncertainty
# argument list by default, so switching use_uncertainty to False here also
# requires changing how they invoke the criterion.
if use_uncertainty:
    if loss_func == 'digamma':
        criterion = edl_digamma_loss
    elif loss_func == 'log':
        criterion = edl_log_loss
    elif loss_func == 'mse':
        criterion = edl_mse_loss
    else:
        # No argparse `parser` exists in this notebook, so calling parser.error
        # here would itself fail with a NameError; raise a real error instead.
        raise ValueError(
            "use_uncertainty requires loss_func to be 'mse', 'log' or 'digamma', "
            "got %r." % (loss_func,)
        )
else:
    criterion = nn.CrossEntropyLoss()

# optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=0.005)
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

In [10]:
# Write the run configuration to a fresh log file (mode 'w' replaces any earlier log).
with open(out_directory + 'log.txt', 'w') as outfile:
    log_lines = [
        '==== Training details ====\n',
        '---- Model structure ----\n',
        'Loading GLOVE embedding:  %s.     dictionary dim: %d. \n' %(load_embd, n_words),
        'Hidden dimension: %d.     Output dimension: %d.\n' %(dim_dict[dataset_type], n_answers),
        '\n---- Training detials ----\n',
        'Batch size:  %d.     RA_decay: %f\n' %(batch_size, decay),
        'Learning rate: %f.     Epochs: %d\n' %(learning_rate, n_epoch),
    ]
    outfile.write(''.join(log_lines))

In [None]:
# One row per epoch: [train_acc, val_acc, train_loss, val_loss].
learning_curve = np.zeros([0,4])
acc_best = 0
curve_path = out_directory+'learn_curve.csv'
checkpoint_path = out_directory+'checkpoint.model'

for epoch in range(n_epoch):
    train(epoch, dataset_type, train_object)
    # Both splits are scored with the running-average network inside valid().
    train_acc, train_loss = valid(epoch, dataset_type, train_object)
    val_acc, val_loss = valid(epoch, dataset_type, val_object)

    # Append this epoch's metrics and rewrite the CSV so progress survives an interruption.
    epoch_row = np.array([[train_acc,val_acc,train_loss,val_loss]])
    learning_curve = np.append(learning_curve, epoch_row, axis = 0)
    savetxt(curve_path, learning_curve, delimiter=',')

    # Keep a single checkpoint: the running-average weights with the best validation accuracy.
    if val_acc > acc_best:
        with open(checkpoint_path, 'wb') as f:
            torch.save(net_running.state_dict(), f)
        print('Accuracy increased from %.4f to %.4f, saved to %s. '%(acc_best, val_acc, checkpoint_path))
        acc_best = val_acc

print('The best validation accuracy: ', acc_best)



Epoch: 1; CurLoss: 2.87849569; CurAcc: 0.14844; Tot_Acc: 0.14844:   0%|          | 0/5469 [00:00<?, ?it/s][A
Epoch: 1; CurLoss: 2.87849569; CurAcc: 0.14844; Tot_Acc: 0.14844:   0%|          | 1/5469 [00:00<1:20:31,  1.13it/s][A
Epoch: 1; CurLoss: 2.76361370; CurAcc: 0.17188; Tot_Acc: 0.16016:   0%|          | 1/5469 [00:01<1:20:31,  1.13it/s][A
Epoch: 1; CurLoss: 2.76361370; CurAcc: 0.17188; Tot_Acc: 0.16016:   0%|          | 2/5469 [00:01<1:13:40,  1.24it/s][A
Epoch: 1; CurLoss: 2.87284565; CurAcc: 0.21875; Tot_Acc: 0.17969:   0%|          | 2/5469 [00:02<1:13:40,  1.24it/s][A
Epoch: 1; CurLoss: 2.87284565; CurAcc: 0.21875; Tot_Acc: 0.17969:   0%|          | 3/5469 [00:02<1:08:53,  1.32it/s][A
Epoch: 1; CurLoss: 2.91383147; CurAcc: 0.20312; Tot_Acc: 0.18555:   0%|          | 3/5469 [00:02<1:08:53,  1.32it/s][A
Epoch: 1; CurLoss: 2.91383147; CurAcc: 0.20312; Tot_Acc: 0.18555:   0%|          | 4/5469 [00:02<1:05:38,  1.39it/s][A
Epoch: 1; CurLoss: 2.85994911; CurAcc: 0.17969; 