# Lamb optimizer

In [None]:
"""Lamb optimizer."""

import collections
import math

import torch
from tensorboardX import SummaryWriter
from torch.optim import Optimizer


def log_lamb_rs(optimizer: Optimizer, event_writer: "SummaryWriter", token_count: int):
    """Log histograms of the per-parameter LAMB statistics across layers.

    For every parameter whose optimizer state contains any of ``'weight_norm'``,
    ``'adam_norm'`` or ``'trust_ratio'`` (the keys written by ``Lamb.step``), the
    values are gathered per key and written as one histogram under ``lamb/<key>``.

    Args:
        optimizer: optimizer whose ``param_groups`` and ``state`` are inspected.
        event_writer: object exposing ``add_histogram(tag, values, step)``,
            e.g. a tensorboardX ``SummaryWriter``.
        token_count: step value passed to ``add_histogram``.
    """
    results = collections.defaultdict(list)
    for group in optimizer.param_groups:
        for p in group['params']:
            state = optimizer.state[p]
            for name in ('weight_norm', 'adam_norm', 'trust_ratio'):
                if name in state:
                    # Values may be 0-dim tensors living on a GPU; normalise them to
                    # detached float32 CPU tensors so they can be stacked together.
                    results[name].append(torch.as_tensor(state[name]).detach().float().cpu())

    for name, values in results.items():
        event_writer.add_histogram(f'lamb/{name}', torch.stack(values), token_count)

class Lamb(Optimizer):
    r"""Implements Lamb algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    Following paper v3, no bias correction is applied to the moment estimates.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): coefficient; ``weight_decay * p`` is added
            to the Adam step before the trust ratio is computed (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam (without bias correction). Useful for comparison purposes.

    After every step the per-parameter state additionally holds ``weight_norm``,
    ``adam_norm`` and ``trust_ratio`` (0-dim tensors) for logging.

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(Lamb, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure must be able to backprop.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1

                # Decay the first and second moment running averages. The keyword
                # forms replace the deprecated add_(scalar, tensor) overloads.
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Paper v3 does not use debiasing, so the learning rate is used as-is.
                step_size = group['lr']

                weight_norm = p.data.pow(2).sum().sqrt()

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])

                adam_norm = adam_step.pow(2).sum().sqrt()
                # The 1e-6 avoids a division by zero (NaN) when adam_norm == 0.
                # A zero weight norm would give a trust ratio of 0 and freeze the
                # parameter at zero forever, so fall back to 1 (a plain Adam step).
                trust_ratio = torch.where(weight_norm > 0,
                                          weight_norm / adam_norm.add(1e-6),
                                          torch.ones_like(weight_norm))
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                # p <- p - lr * trust_ratio * adam_step (trust_ratio may be a 0-dim tensor)
                p.data.add_(adam_step * trust_ratio, alpha=-step_size)

        return loss


# preprocess

In [None]:
import numpy as np
import torch
import os
import sys
import re
import math
from torch.utils.data import Dataset, DataLoader

dataset_name = 'neutral'#'interpolation'
if len(sys.argv) < 2:
    # sys.exit(msg) prints to stderr and exits with status 1; a bare exit()
    # would report success to the calling shell.
    sys.exit("Missing data path!")

datapath = os.path.join(sys.argv[1], dataset_name)
datapath_preprocessed = os.path.join(sys.argv[1], dataset_name + '_preprocessed')

# exist_ok allows an interrupted run to be restarted instead of failing on mkdir.
os.makedirs(datapath_preprocessed, exist_ok=True)

all_data = os.listdir(datapath)

for filename in all_data:
    # Skip stray non-archive files (e.g. hidden OS files) that np.load cannot read.
    if not filename.endswith('.npz'):
        continue
    with np.load(os.path.join(datapath, filename)) as data:
        # 16 panels of 160x160, downsampled to 80x80 by keeping every second row/column.
        image = data['image'].astype(np.uint8).reshape(16, 160, 160)[:,::2,::2]
        target = data['target']
        np.savez_compressed(os.path.join(datapath_preprocessed, filename), image=image, target=target)

print('Done')


# train_mixed_precision_distributed

In [2]:
import numpy as np
import torch
import os
import sys
import re
import math
from torch.utils.data import Dataset, DataLoader
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
# from lamb import Lamb

#tensorboard for accuracy graphs
import tensorflow as tf

def getCombinations(inputTensor, N, c, d):
    """Concatenate every ordered pair of objects.

    Args:
        inputTensor: tensor with N * c * d elements, viewed as (N, c, d):
            N sets of c objects with d features each.
        N, c, d: the sizes above.

    Returns:
        Tensor of shape (N, c, c, 2 * d) whose entry [n, i, j] is the
        concatenation of object i and object j of set n.
    """
    objects = inputTensor.reshape(N, c, d)
    first = objects.unsqueeze(2).expand(N, c, c, d)   # first[n, i, j] = objects[n, i]
    second = objects.unsqueeze(1).expand(N, c, c, d)  # second[n, i, j] = objects[n, j]
    return torch.cat((first, second), dim=3)

dataset_name = 'neutral'#'interpolation'#'extrapolation'


if len(sys.argv) < 2:
    # sys.exit(msg) prints to stderr and exits with status 1; a bare exit()
    # would report success.
    sys.exit("Missing data path!")

# Directory of preprocessed .npz files that PgmDataset reads from.
datapath_preprocessed = os.path.join(sys.argv[1], dataset_name + '_preprocessed')

class PgmDataset(Dataset):
    """Map-style dataset over preprocessed PGM ``.npz`` files.

    Each item is ``(image, target)``: ``image`` is the stored panel array cast to
    uint8 and reshaped to (16, 80, 80); ``target`` is the stored ``'target'``
    entry. Files are resolved relative to the module-level
    ``datapath_preprocessed`` directory.
    """

    def __init__(self, filenames):
        """Keep the list of file names (relative to ``datapath_preprocessed``)."""
        self.filenames = filenames

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        path = os.path.join(datapath_preprocessed, self.filenames[index])
        with np.load(path) as archive:
            panels = archive['image'].astype(np.uint8).reshape(16, 80, 80)
            label = archive['target']
        return panels, label

class WReN(torch.nn.Module):
    """Relation-network model that scores the 8 candidate answers of a PGM puzzle.

    Each sample has 16 panels: 8 context panels followed by 8 candidate answers.
    Every panel is embedded by a small CNN. Each candidate is grouped with the
    8 context embeddings into a set of 9 objects, which receive one-hot position
    tags. A multi-layer relation network then reduces each set to a single score:
    ``g`` runs over all object pairs, the ``m`` layers in ``h`` run over pairs of
    aggregated objects, and ``f``/``f_final`` produce the score.

    Args:
        m: number of relation layers stored in ``self.h`` (applied after ``g``).
    """

    def __init__(self, m):
        super(WReN, self).__init__()
        self.relation_network_depth = m

        self.g_dim = 512
        self.h_dim = 256
        self.f_dim = 256

        self.use_mag_enc = True #switch between scalar input and magnitude encoded input
        self.mag_enc_type_relu = False #switch between gaussian magnitude encoding and relu based magnitude encoding

        self.magnitude_encoding_dim = 20
        #magnitude encoding: a pixel value x is first mapped to x * input_scale + input_offset,
        #i.e. the range 0..255 becomes -1..1
        self.input_scale = 2.0/255.0
        self.input_offset = -1.0
        std_dev = 0.28
        self.input_encoding_variance_inv = 1.0 / (math.sqrt(2.0) * std_dev)
        self.normalization_factor = 1.0 / (math.sqrt(2*math.pi) * std_dev)
        #encoding centres; a frozen Parameter, so it is saved in the state dict and follows .to()
        self.mag_scale = torch.nn.Parameter(torch.linspace(-1.0, 1.0, steps=self.magnitude_encoding_dim), requires_grad=False)

        if self.use_mag_enc:
            conv_input_dim = self.magnitude_encoding_dim
        else:
            conv_input_dim = 1

        #four stride-2 3x3 convolutions: 80x80 -> 39 -> 19 -> 9 -> 4x4
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(conv_input_dim, 32, 3, stride=2), 
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(32, 32, 3, stride=2), 
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(32, 32, 3, stride=2), 
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(32, 32, 3, stride=2), 
            torch.nn.LeakyReLU()
        )
        #panel embedding of 256-9 features; the remaining 9 dims are filled by the position tags
        self.post_cnn_linear = torch.nn.Linear(32*4*4, 256-9)

        #one-hot tag for each of the 9 objects, repeated for the 8 candidates: shape (72, 9)
        self.tag_matrix = torch.nn.Parameter(torch.eye(9).repeat(8, 1), requires_grad=False)

        #pairwise relation MLP over concatenated 256-dim object pairs
        self.g = torch.nn.Sequential(
                torch.nn.Linear(2*256, self.g_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.g_dim, self.g_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.g_dim, self.g_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.g_dim, self.h_dim),
                torch.nn.LeakyReLU()
            )

        #additional relation layers over pairs of aggregated objects
        h = []
        for _ in range(m):
            rel_layer_func = torch.nn.Sequential(
                torch.nn.Linear(2*self.h_dim, self.h_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.h_dim, self.h_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.h_dim, self.h_dim), 
                torch.nn.LeakyReLU()
            )
            h.append(rel_layer_func)

        self.h = torch.nn.ModuleList(h)

        f_in_dim = self.h_dim
        self.f = torch.nn.Sequential(
                torch.nn.Linear(f_in_dim, self.f_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.f_dim, self.f_dim), 
                torch.nn.LeakyReLU()
            )

        self.f_final = torch.nn.Linear(self.f_dim, 1)
        

    def forward(self, batch):
        """Score the 8 candidate answers of every puzzle in ``batch``.

        Args:
            batch: float tensor holding batch_size * 16 panels of 80x80 values
                (e.g. shape (batch_size, 16, 80, 80)); panels 0-7 are the context,
                panels 8-15 the candidates. The tensor is not modified.

        Returns:
            ``(answer, activation_loss)``: ``answer`` has shape (batch_size, 8)
            and holds one pre-softmax score per candidate. ``activation_loss`` is
            the scalar mean(hSum**2) + mean(result**2), a penalty on late
            activations.
        """
        batch_size = batch.size()[0]
        #Panel preprocessor CNN
        batch_flat = batch.reshape(batch_size*16, 1, 80, 80)

        if self.use_mag_enc:
            with torch.no_grad():
                #magnitude encoding: move the singleton channel last so mag_scale broadcasts over it
                batch_flat = batch_flat.transpose(1, 3)
                #The first op of each chain is out-of-place: batch_flat may be a view of the
                #caller's tensor, and the in-place ops that follow would otherwise overwrite it.
                if self.mag_enc_type_relu:
                    #first order
                    batch_flat = batch_flat.add(255/self.magnitude_encoding_dim)
                    batch_flat = torch.nn.functional.relu_(batch_flat.mul_(self.input_scale).add_(self.input_offset).add(-self.mag_scale))
                    #second order
                    batch_flat = torch.cat((batch_flat[:, :, :, :-1] - 2*batch_flat[:, :, :, 1:], batch_flat[:, :, :, -1].unsqueeze(dim=-1)), dim=-1).mul_(self.magnitude_encoding_dim/2)
                    batch_flat = torch.nn.functional.relu_(batch_flat)
                else:
                    #for every centre m in mag_scale:
                    #exp(-((tanh(x') + m) * input_encoding_variance_inv)**2) * normalization_factor
                    batch_flat = batch_flat.mul(self.input_scale).add_(self.input_offset).tanh_().add(self.mag_scale).mul_(self.input_encoding_variance_inv).pow_(2).mul_(-1).exp_().mul_(self.normalization_factor)
                #back to (batch_size*16, magnitude_encoding_dim, 80, 80)
                batch_flat = batch_flat.transpose(3, 1)

        conv_out = self.conv(batch_flat)
        #scatter context
        objectsWithoutPos = self.post_cnn_linear(conv_out.reshape(batch_size*16, -1))
        panel_vectors = objectsWithoutPos.reshape(batch_size, 16, 256-9)
        given, option1, option2, option3, option4, option5, option6, option7, option8 = panel_vectors.split((8, 1, 1, 1, 1, 1, 1, 1, 1), dim=1)
        #one group of 9 objects (8 context + 1 candidate) per candidate
        optionsWithContext = torch.cat((
            given, option1, 
            given, option2, 
            given, option3, 
            given, option4, 
            given, option5, 
            given, option6, 
            given, option7, 
            given, option8
        ), 1)
        optionsWithoutPos = optionsWithContext.reshape(batch_size*8*9, 256-9)

        objects = torch.cat((optionsWithoutPos, self.tag_matrix.repeat(batch_size, 1)), dim=1).reshape(batch_size*8, 9, 256-9+9)

        #MLRN
        objPairs2D = getCombinations(objects, batch_size*8, 9, 256)
        objPairs = objPairs2D.reshape(batch_size*8*(9*9), 2*256)

        gResult = self.g(objPairs)#apply MLP

        prev_result = gResult
        prev_dim = self.h_dim
        prev_result_2d = prev_result.reshape(batch_size*8, 9, 9, prev_dim)
        #aggregate over the second object of each pair: (batch_size*8, 9, prev_dim)
        sum_j = prev_result_2d.sum(dim=2)
        for h_layer in self.h:
            intermed_obj_pairs_2d = getCombinations(sum_j, batch_size*8, 9, prev_dim)
            intermed_obj_pairs = intermed_obj_pairs_2d.reshape(batch_size*8*(9*9), 2*prev_dim)
            prev_result = h_layer(intermed_obj_pairs)#apply MLP
            prev_dim = self.h_dim
            prev_result_2d = prev_result.reshape(batch_size*8, 9, 9, prev_dim)
            sum_j = prev_result_2d.sum(dim=2)

        hSum = sum_j.sum(dim=1)
        result = self.f_final(self.f(hSum))#pre-softmax scores for every possible answer

        answer = result.reshape(batch_size, 8)

        #attempt to stabilize training (avoiding inf value activations in last layers) 
        activation_loss = hSum.pow(2).mean() + result.pow(2).mean()

        return answer, activation_loss

def worker_fn(rank, world_size):
    """Train WReN on one GPU as one process of a DistributedDataParallel job.

    Intended to be launched once per process via ``torch.multiprocessing.spawn``.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: number of processes; the global batch size (512) is split
            evenly across them.

    Side effects: loads/saves ``weights.pt`` (plus ``weights.pt_cp<epoch>``
    checkpoints on cold start) in the working directory. Rank 0 also writes
    accuracy summaries to ``log``.
    """
    setup(rank, world_size)

    weights_filename = "weights.pt"
    batch_size = 512
    epochs = 240
    warmup_epochs = 8
    use_mixed_precision = True

    batch_size = batch_size // world_size #batch size per worker

    #Data
    all_data = os.listdir(datapath_preprocessed)
    train_filenames = [p for p in all_data if re.match(r'^PGM_' + re.escape(dataset_name) + r'_train_(\d+)\.npz$', p) is not None]
    val_filenames = [p for p in all_data if re.match(r'^PGM_' + re.escape(dataset_name) + r'_val_(\d+)\.npz$', p) is not None]
    train_dataset = PgmDataset(train_filenames)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8, pin_memory=False, sampler=train_sampler)#shuffle is done by the sampler
    #not sharded: every rank evaluates the full validation set
    val_dataloader = DataLoader(PgmDataset(val_filenames), batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=False)

    #Model
    device_ids = [rank]

    model = WReN(2).to(device_ids[0])#3-layer MLRN

    if weights_filename is not None and os.path.isfile("./" + weights_filename):
        model.load_state_dict(torch.load(weights_filename, map_location='cpu'))
        print('Weights loaded')
        cold_start = False
    else:
        print('No weights found')
        cold_start = True

    #Loss and optimizer
    final_lr = 2e-3

    def add_module_params_with_decay(module, weight_decay, param_groups):#adds parameters with decay unless they are bias parameters, which shouldn't receive decay
        group_with_decay = []
        group_without_decay = []
        for name, param in module.named_parameters():
            if not param.requires_grad: continue
            if name == 'bias' or name.endswith('bias'):
                group_without_decay.append(param)
            else:
                group_with_decay.append(param)
        param_groups.append({"params": group_with_decay, "weight_decay": weight_decay})
        param_groups.append({"params": group_without_decay})

    optimizer_param_groups = [
    ]

    add_module_params_with_decay(model.conv, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.post_cnn_linear, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.g, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.h, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.f, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.f_final, 2e-1, optimizer_param_groups)

    optimizer = Lamb(optimizer_param_groups, lr=final_lr)

    base_model = model
    if use_mixed_precision:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1") #Mixed Precision

    lossFunc = torch.nn.CrossEntropyLoss()
    softmax = torch.nn.Softmax(dim=1)

    #Parallel distributed model
    device = device_ids[0]
    torch.cuda.set_device(device)
    parallel_model = torch.nn.parallel.DistributedDataParallel(model, device_ids)

    if rank == 0:
        #accuracy logging
        sess = tf.Session()
        train_acc_placeholder = tf.placeholder(tf.float32, shape=())
        train_acc_summary = tf.summary.scalar('training_acc', train_acc_placeholder)
        val_acc_placeholder = tf.placeholder(tf.float32, shape=())
        val_acc_summary = tf.summary.scalar('validation_acc', val_acc_placeholder)
        writer = tf.summary.FileWriter("log", sess.graph)

    #training loop
    acc = []
    global_step = 0
    for epoch in range(epochs): 
        train_sampler.set_epoch(epoch) 

        # Validation
        val_acc = []
        parallel_model.eval()
        with torch.no_grad():
            for i, (local_batch, local_labels) in enumerate(val_dataloader):
                local_batch, targets = local_batch.to(device), local_labels.to(device)

                answer, _ = parallel_model(local_batch.type(torch.float32))

                #Calc accuracy
                answerSoftmax = softmax(answer)
                maxIndex = answerSoftmax.argmax(dim=1)

                correct = maxIndex.eq(targets)
                #float32: summing many float16 batch accuracies loses precision
                accuracy = correct.type(dtype=torch.float32).mean(dim=0)
                val_acc.append(accuracy)

                if i % 50 == 0 and rank == 0:
                    print("batch " + str(i))

        #guard against an empty validation set (division by zero)
        if val_acc:
            total_val_acc = sum(val_acc) / len(val_acc)
            print('Validation accuracy: ' + str(total_val_acc.item()))
            if rank == 0:
                summary = sess.run(val_acc_summary, feed_dict={val_acc_placeholder: total_val_acc.item()})
                writer.add_summary(summary, global_step=global_step)

        # Training
        parallel_model.train()
        for i, (local_batch, local_labels) in enumerate(train_dataloader):
            global_step = global_step + 1

            if cold_start and epoch < warmup_epochs:#linear scaling of the lr for warmup during the first few epochs
                lr = final_lr * global_step / (warmup_epochs*len(train_dataset) / (batch_size * world_size))
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            local_batch, targets = local_batch.to(device_ids[0]), local_labels.to(device_ids[0])

            optimizer.zero_grad()
            answer, activation_loss = parallel_model(local_batch.type(torch.float32))

            loss = lossFunc(answer, targets) + activation_loss * 2e-3

            #Calc accuracy
            answerSoftmax = softmax(answer)
            maxIndex = answerSoftmax.argmax(dim=1)

            correct = maxIndex.eq(targets)
            #float32: these values are summed over up to 1000 batches below
            accuracy = correct.type(dtype=torch.float32).mean(dim=0)
            acc.append(accuracy)
            
            #Training step
            if use_mixed_precision:
                with amp.scale_loss(loss, optimizer) as scaled_loss: #Mixed precision
                    scaled_loss.backward()
            else:
                loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(parallel_model.parameters(), 1e1)

            optimizer.step()

            if i % 50 == 0 and rank == 0:
                print("epoch " + str(epoch) + " batch " + str(i))
                print("loss", loss)
                print("activation loss", activation_loss)
                print(grad_norm)

            #logging and saving weights
            if i % 1000 == 999:
                trainAcc = sum(acc) / len(acc)
                acc = []
                print('Training accuracy: ' + str(trainAcc.item()))
                if rank == 0:
                    if weights_filename is not None:
                        torch.save(base_model.state_dict(), weights_filename)
                        print('Weights saved')

                    summary = sess.run(train_acc_summary, feed_dict={train_acc_placeholder: trainAcc.item()})
                    writer.add_summary(summary, global_step=global_step)  

        if cold_start and weights_filename is not None and epoch % 10 == 0 and rank == 0:
            torch.save(base_model.state_dict(), weights_filename + "_cp" + str(epoch))
            print('Checkpoint saved')


    cleanup()

def setup(rank, world_size):
    """Join the NCCL process group for this worker and seed the torch RNG.

    Args:
        rank: index of this process within the group.
        world_size: total number of processes in the group.
    """
    # Single-node rendezvous on a fixed local address/port.
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12355')

    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)

    # Every process uses the same seed so all replicas are initialised with
    # identical weights and biases.
    torch.manual_seed(42)

def cleanup():
    """Destroy the default process group created by ``init_process_group``."""
    dist = torch.distributed
    dist.destroy_process_group()

def run(world_size):
    """Spawn ``world_size`` processes running ``worker_fn`` and wait for all of them.

    Each process is called as ``worker_fn(rank, world_size)``.
    """
    torch.multiprocessing.spawn(
        worker_fn,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

if __name__ == "__main__":
    num_gpus = 4  # one worker process per GPU
    run(num_gpus)


ModuleNotFoundError: No module named 'apex'

# test.py

In [None]:
import numpy as np
import torch
import os, sys
import re
from torch.utils.data import Dataset, DataLoader
import math

#tensorboard for accuracy graphs
import tensorflow as tf

def getCombinations(inputTensor, N, c, d):
    """Return all ordered object pairs, concatenated along the feature axis.

    ``inputTensor`` must contain N * c * d elements (N sets of c objects with d
    features). The result has shape (N, c, c, 2 * d); entry [n, i, j] is object
    i of set n followed by object j of set n.
    """
    grid = inputTensor.reshape(N, 1, c, d).expand(N, c, c, d)  # grid[n, i, j] = x[n, j]
    firsts = grid.permute(0, 2, 1, 3)                          # firsts[n, i, j] = x[n, i]
    return torch.cat([firsts, grid], dim=-1)

# GPUs used for testing (four are assumed to be available).
devices = (torch.device("cuda:0"), torch.device("cuda:1"), torch.device("cuda:2"), torch.device("cuda:3"))

if len(sys.argv) < 2:
    # sys.exit(msg) prints to stderr and exits with status 1; a bare exit()
    # would report success.
    sys.exit("Missing data path!")

dataset_name = 'neutral'#'interpolation'#'extrapolation'
# Raw (not preprocessed) dataset directory; PgmDataset downsamples on load.
datapath = os.path.join(sys.argv[1],dataset_name)

class PgmDataset(Dataset):
    """Map-style dataset over raw PGM ``.npz`` files in ``datapath``.

    Items are ``(image, target, meta)``. ``image`` is the stored panel array,
    reshaped to (16, 160, 160) and downsampled to (16, 80, 80) by keeping every
    second row and column. ``target`` is the stored ``'target'`` entry and
    ``meta`` is its ``'relation_structure'`` entry.
    """

    def __init__(self, filenames):
        """Keep the list of file names (relative to ``datapath``)."""
        self.filenames = filenames

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        with np.load(os.path.join(datapath, self.filenames[index])) as archive:
            full_res = archive['image'].astype(np.uint8).reshape(16, 160, 160)
            panels = full_res[:, ::2, ::2]
            label = archive['target']
            relations = archive['relation_structure']
        return panels, label, relations

def custom_collate_fn(batch):
    """Collate ``(image, target, meta)`` samples into a batch.

    Images and targets (numpy arrays) are converted with ``torch.from_numpy``
    and stacked along a new leading dimension. The per-sample metadata is
    returned untouched as a tuple, because it cannot be stacked.
    """
    images, targets, metas = zip(*batch)
    image_batch = torch.stack(list(map(torch.from_numpy, images)))
    target_batch = torch.stack(list(map(torch.from_numpy, targets)))
    return image_batch, target_batch, metas

batch_size = 32

all_data = os.listdir(datapath)
# Test split archives are named PGM_<dataset_name>_test_<number>.npz.
test_file_pattern = re.compile(r'^PGM_' + re.escape(dataset_name) + r'_test_(\d+)\.npz$')
test_filenames = [name for name in all_data if test_file_pattern.match(name) is not None]
test_dataloader = DataLoader(
    PgmDataset(test_filenames),
    batch_size=batch_size,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
    collate_fn=custom_collate_fn,
)

class WReN(torch.nn.Module):
    """Relation-network model that scores the 8 candidate answers of a PGM puzzle.

    Each sample has 16 panels: 8 context panels followed by 8 candidate answers.
    Every panel is embedded by a small CNN. Each candidate is grouped with the
    8 context embeddings into a set of 9 objects, which receive one-hot position
    tags. ``g`` (over all object pairs), the ``m`` layers in ``h`` (over pairs of
    aggregated objects) and ``f``/``f_final`` then reduce each set to a score.

    Args:
        m: number of relation layers stored in ``self.h`` (applied after ``g``).
    """

    def __init__(self, m):
        super(WReN, self).__init__()
        self.relation_network_depth = m

        self.g_dim = 512
        self.h_dim = 256
        self.f_dim = 256

        self.use_mag_enc = True #switch between scalar input and magnitude encoded input
        self.mag_enc_type_relu = False #switch between gaussian magnitude encoding and relu based magnitude encoding

        self.magnitude_encoding_dim = 20
        #model
        #magnitude encoding: a pixel value x is first mapped to x * input_scale + input_offset (0..255 -> -1..1)
        self.input_scale = 2.0/255.0
        self.input_offset = -1.0
        std_dev = 0.28
        self.input_encoding_variance_inv = 1.0 / (math.sqrt(2.0) * std_dev)
        self.normalization_factor = 1.0 / (math.sqrt(2*math.pi) * std_dev)
        #encoding centres; a frozen Parameter, so it is saved in the state dict and follows .to()
        self.mag_scale = torch.nn.Parameter(torch.linspace(-1.0, 1.0, steps=self.magnitude_encoding_dim), requires_grad=False)

        if self.use_mag_enc:
            conv_input_dim = self.magnitude_encoding_dim
        else:
            conv_input_dim = 1

        #four stride-2 3x3 convolutions: 80x80 -> 39 -> 19 -> 9 -> 4x4
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(conv_input_dim, 32, 3, stride=2), 
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(32, 32, 3, stride=2), 
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(32, 32, 3, stride=2), 
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(32, 32, 3, stride=2), 
            torch.nn.LeakyReLU()
        )
        self.post_cnn_linear = torch.nn.Linear(32*4*4, 256-9)#input = 32 feature maps with 4x4 resolution

        #one-hot tag for each of the 9 objects, repeated for the 8 candidates: shape (72, 9)
        self.tag_matrix = torch.nn.Parameter(torch.eye(9).repeat(8, 1), requires_grad=False)

        self.g = torch.nn.Sequential(
                torch.nn.Linear(2*256, self.g_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.g_dim, self.g_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.g_dim, self.g_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.g_dim, self.h_dim),
                torch.nn.LeakyReLU()
            )

        h = []
        for _ in range(m):
            rel_layer_func = torch.nn.Sequential(
                torch.nn.Linear(2*self.h_dim, self.h_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.h_dim, self.h_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.h_dim, self.h_dim), 
                torch.nn.LeakyReLU()
            )
            h.append(rel_layer_func)

        self.h = torch.nn.ModuleList(h)

        f_in_dim = self.h_dim
        self.f = torch.nn.Sequential(
                torch.nn.Linear(f_in_dim, self.f_dim), 
                torch.nn.LeakyReLU(),
                torch.nn.Linear(self.f_dim, self.f_dim), 
                torch.nn.LeakyReLU()
            )

        self.f_final = torch.nn.Linear(self.f_dim, 1)
        

    def forward(self, batch):
        """Score the 8 candidate answers of every puzzle in ``batch``.

        Args:
            batch: float tensor holding batch_size * 16 panels of 80x80 values
                (e.g. shape (batch_size, 16, 80, 80)); panels 0-7 are the context,
                panels 8-15 the candidates. The tensor is not modified.

        Returns:
            ``(answer, activation_loss)``: ``answer`` has shape (batch_size, 8)
            and holds one pre-softmax score per candidate. ``activation_loss`` is
            the scalar mean(hSum**2) + mean(result**2).
        """
        batch_size = batch.size()[0]
        #Panel preprocessor CNN
        batch_flat = batch.reshape(batch_size*16, 1, 80, 80)#16 images per sample: 8 for context + 8 answer options

        if self.use_mag_enc:
            with torch.no_grad():
                #magnitude encoding: move the singleton channel last so mag_scale broadcasts over it
                batch_flat = batch_flat.transpose(1, 3)
                #The first op of each chain is out-of-place: batch_flat may be a view of the
                #caller's tensor, and the in-place ops that follow would otherwise overwrite it.
                if self.mag_enc_type_relu:
                    #first order
                    batch_flat = batch_flat.add(255/self.magnitude_encoding_dim)
                    batch_flat = torch.nn.functional.relu_(batch_flat.mul_(self.input_scale).add_(self.input_offset).add(-self.mag_scale))
                    #second order
                    batch_flat = torch.cat((batch_flat[:, :, :, :-1] - 2*batch_flat[:, :, :, 1:], batch_flat[:, :, :, -1].unsqueeze(dim=-1)), dim=-1).mul_(self.magnitude_encoding_dim/2)
                    batch_flat = torch.nn.functional.relu_(batch_flat)
                else:
                    #for every centre m in mag_scale:
                    #exp(-((tanh(x') + m) * input_encoding_variance_inv)**2) * normalization_factor
                    batch_flat = batch_flat.mul(self.input_scale).add_(self.input_offset).tanh_().add(self.mag_scale).mul_(self.input_encoding_variance_inv).pow_(2).mul_(-1).exp_().mul_(self.normalization_factor)
                batch_flat = batch_flat.transpose(3, 1)

        conv_out = self.conv(batch_flat)
        #scatter context
        objectsWithoutPos = self.post_cnn_linear(conv_out.reshape(batch_size*16, -1))
        panel_vectors = objectsWithoutPos.reshape(batch_size, 16, 256-9)
        given, option1, option2, option3, option4, option5, option6, option7, option8 = panel_vectors.split((8, 1, 1, 1, 1, 1, 1, 1, 1), dim=1)
        optionsWithContext = torch.cat((
            given, option1, 
            given, option2, 
            given, option3, 
            given, option4, 
            given, option5, 
            given, option6, 
            given, option7, 
            given, option8
        ), 1)
        optionsWithoutPos = optionsWithContext.reshape(batch_size*8*9, 256-9)

        objects = torch.cat((optionsWithoutPos, self.tag_matrix.repeat(batch_size, 1)), dim=1).reshape(batch_size*8, 9, 256-9+9)#8 answers to score per sample, 9 images (8 from context + 1 from answer) per answer option

        #MLRN
        objPairs2D = getCombinations(objects, batch_size*8, 9, 256)
        objPairs = objPairs2D.reshape(batch_size*8*(9*9), 2*256)

        gResult = self.g(objPairs)#apply MLP

        prev_result = gResult
        prev_dim = self.h_dim
        prev_result_2d = prev_result.reshape(batch_size*8, 9, 9, prev_dim)
        sum_j = prev_result_2d.sum(dim=2)
        for h_layer in self.h:
            intermed_obj_pairs_2d = getCombinations(sum_j, batch_size*8, 9, prev_dim)
            intermed_obj_pairs = intermed_obj_pairs_2d.reshape(batch_size*8*(9*9), 2*prev_dim)
            prev_result = h_layer(intermed_obj_pairs)#apply MLP
            prev_dim = self.h_dim
            prev_result_2d = prev_result.reshape(batch_size*8, 9, 9, prev_dim)
            sum_j = prev_result_2d.sum(dim=2)

        hSum = sum_j.sum(dim=1)
        result = self.f_final(self.f(hSum))#pre-softmax scores for every possible answer

        answer = result.reshape(batch_size, 8)#scores of the 8 possible answers for every sample

        activation_loss = hSum.pow(2).mean() + result.pow(2).mean()

        return answer, activation_loss

model = WReN(2).to(devices[0]) #3-layer MLRN

if os.path.isfile("./weights.pt"):
    # map_location makes loading independent of the device the tensors were saved from.
    model.load_state_dict(torch.load("weights.pt", map_location=devices[0]))
    print('Weights loaded')
else:
    print('No weights found')
    sys.exit(1)


softmax = torch.nn.Softmax(dim=1)

parallel_model = torch.nn.DataParallel(model, device_ids=devices)

model.eval()


def _tally(stats, key, is_correct):
    """Count one sample under ``key`` in ``stats``, and whether it was answered correctly."""
    entry = stats.setdefault(key, {'total': 0, 'correct': 0})
    entry['total'] += 1
    entry['correct'] += int(is_correct)


def _print_accuracies(stats):
    """Print '<key> <accuracy in percent>' for every entry of ``stats``."""
    for key, entry in stats.items():
        print(str(key) + ' ' + str(100 * entry['correct'] / entry['total']))


total_correct = 0
total_samples = 0
objTypes = {}
attrTypes = {}
relTypes = {}
single_rel_correct = 0
single_rel_total = 0
# Testing
with torch.no_grad():
    for i, (local_batch, local_labels, meta) in enumerate(test_dataloader):
        local_batch, targets = local_batch.to(devices[0]), local_labels.to(devices[0])

        answer, _ = parallel_model(local_batch.type(torch.float32))

        #Calc accuracy
        answerSoftmax = softmax(answer)
        maxIndex = answerSoftmax.argmax(dim=1)

        correct = maxIndex.eq(targets)
        # Count samples rather than averaging batch means, so a smaller final
        # batch is not over-weighted.
        total_correct += int(correct.sum().item())
        total_samples += correct.numel()

        # One device->host transfer per batch instead of one .item() per sample.
        for j, jCorrect in enumerate(correct.tolist()):
            if len(meta[j]) == 1:
                # Per-type statistics only for puzzles with a single relation;
                # its entries are read as (object type, attribute type, relation type).
                single_rel_total += 1
                single_rel_correct += int(jCorrect)
                _tally(objTypes, meta[j][0][0], jCorrect)
                _tally(attrTypes, meta[j][0][1], jCorrect)
                _tally(relTypes, meta[j][0][2], jCorrect)

        if i % 50 == 0:
            print("batch " + str(i))

    _print_accuracies(objTypes)
    _print_accuracies(attrTypes)
    _print_accuracies(relTypes)
    print(str(objTypes))
    print(str(attrTypes))
    print(str(relTypes))

    if single_rel_total > 0:
        print('All single relations accuracy: ' + str(100 * single_rel_correct / single_rel_total))
    else:
        print('No single-relation samples found')

    print(total_correct)
    print(total_samples)
    if total_samples > 0:
        print('Test accuracy: ' + str(total_correct / total_samples))
    else:
        print('No test samples found')
