In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [3]:
import os

dataset_path = '/kaggle/input/indonesian-traffic-violation-on-motorcycle/image_crop/image crop'
dataset_images_path = os.path.join(dataset_path, "image")

In [4]:
image_files = os.listdir(dataset_images_path)
num_image_files = len(image_files)
print(f"Jumlah gambar dalam dataset: {num_image_files}")

Jumlah gambar dalam dataset: 1127


In [5]:
import pandas as pd

annotation = pd.read_csv('/kaggle/input/caption-hawww/captionnewbanget.csv')
annotation.head()

Unnamed: 0,filename,caption,category
0,TA_00001,pengendara motor hitam baju abu tanpa pelat no...,tanpa pelat nomor
1,TA_00002,terdapat pengendara motor silver di paling kir...,melanggar marka
2,TA_00003,terdapat lima pengendara motor di depan mobil ...,melanggar marka
3,TA_00004,terdapat pengendara motor hitam baju putih sen...,melanggar marka
4,TA_00005,pengendara motor berboncengan di paling kanan ...,melanggar marka


In [6]:
import re

def get_preprocessed_caption(caption):    
    caption = re.sub(r'\s+', ' ', caption)
    caption = caption.strip()
    caption = re.sub(r'[,.]', '', caption)
    caption = "<start> " + caption.lower() + " <end>" 
    return caption

images_captions_dict = {}

for idx, row in annotation.iterrows():
    filename = row['filename']
    caption = row['caption']
    caption = get_preprocessed_caption(caption)
    images_captions_dict[filename] = caption


In [7]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

resnet_model = models.resnet101(pretrained=True)
resnet_model.eval()
resnet_model = torch.nn.Sequential(*(list(resnet152_model.children())[:-1])) 

def extract_features152(image_path):
    img = Image.open(image_path)
    img = transforms.Resize(224)(img)
    img = transforms.ToTensor()(img)
    img = img[:3, :, :]
    img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
    img = img.unsqueeze(0)

    with torch.no_grad():
        features = resnet_model(img)

    return features

image_path = '/kaggle/input/indonesian-traffic-violation-on-motorcycle/image_crop/image crop/image/TA_00001.png'
image_features = extract_features152(image_path)
print(image_features.shape)
print(image_features)

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:01<00:00, 138MB/s]  


torch.Size([1, 2048, 1, 1])
tensor([[[[1.5437]],

         [[0.0656]],

         [[0.2514]],

         ...,

         [[0.1701]],

         [[0.2219]],

         [[0.2330]]]])


In [8]:
import os
import numpy as np

output_dir = '/kaggle/working/OutputCodeTA/image_feature152'
os.makedirs(output_dir, exist_ok=True)

for filename in os.listdir(dataset_images_path):
    if filename.endswith(".png"): 
        image_path = os.path.join(dataset_images_path, filename)
        features = extract_features152(image_path)
        features = features.squeeze().numpy()

        image_id = os.path.splitext(filename)[0]
        output_file = os.path.join(output_dir, f"{image_id}.npz")

        np.savez_compressed(output_file, feat=features)

In [9]:
from sklearn.model_selection import train_test_split

image_filenames = list(images_captions_dict.keys())
image_filenames_train, image_filenames_temp = train_test_split(image_filenames, test_size=0.2, random_state=3)
image_filenames_val, image_filenames_test = train_test_split(image_filenames_temp, test_size=0.5, random_state=3)

In [10]:
def remove_start_tokens(caption):
    start_token = "<start>"

    words = caption.split()
  
    if words[0] == start_token:
        words = words[1:]
    
    cleaned_caption = ' '.join(words)
    
    return cleaned_caption

In [11]:
def remove_start_and_end_tokens(caption):
    start_token = "<start>"
    end_token = "<end>"

    words = caption.split()

    if words[0] == start_token:
        words = words[1:]

    if words[-1] == end_token:
        words = words[:-1]

    cleaned_caption = ' '.join(words)

    return cleaned_caption

In [12]:
print("Training:", len(image_filenames_train))
print("Validation:", len(image_filenames_val))
print("Test:", len(image_filenames_test))

Training: 901
Validation: 113
Test: 113


In [13]:
label_train = []

for image_filename in image_filenames_train:
    caption = images_captions_dict[image_filename]
    label_train.append(caption)

In [14]:
kata_unik = set()

for teks in label_train:
    tokens = teks.split()  
    kata_unik.update(tokens)

print("Jumlah kata unik dalam dataset pelatihan:", len(kata_unik))

Jumlah kata unik dalam dataset pelatihan: 96


In [15]:
from collections import Counter
import os

word_freq = Counter()

for teks in label_train:
    tokens = teks.split()  
    word_freq.update(tokens)

top_words = [word for word, _ in word_freq.most_common(96)]

output_file = os.path.join('/kaggle/working/OutputCodeTA', 'data_vocabulary.txt')
with open(output_file, "w") as file:
    file.write('<unk>' + "\n")
    for word in top_words:
        file.write(word + "\n")

print("Jumlah kata unik yang disimpan dalam file teks:", len(top_words))

Jumlah kata unik yang disimpan dalam file teks: 96


In [16]:
import tensorflow as tf

top_k = 96
tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k,
                                                  oov_token="<unk>",
                                                  filters='!"#$%&()*+.,-/:;=?@[\]^_`{|}~ ')

tokenizer.fit_on_texts(label_train)
tokenizer.word_index['<pad>'] = 0
tokenizer.index_word[0] = '<pad>'
label_train = tokenizer.texts_to_sequences(label_train)
label_train = tf.keras.preprocessing.sequence.pad_sequences(label_train, padding='post')

2024-07-10 14:40:33.649424: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-10 14:40:33.649539: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-10 14:40:33.775140: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [17]:
max_caption_length = max(len(t) for t in label_train)
print(max_caption_length)

19


In [18]:
index_word = tokenizer.index_word

index = 97 
word = index_word[index]
print(f"Kata dari indeks {index} adalah: '{word}'")

Kata dari indeks 97 adalah: 'dekat'


In [18]:
[tokenizer.index_word[i] for i in label_train[23]]

['<start>',
 'lima',
 'pengendara',
 'motor',
 'di',
 'depan',
 'mobil',
 'abu',
 'berhenti',
 'melewati',
 'garis',
 'marka',
 '<end>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<pad>']

In [41]:
label_train[23]

array([ 3, 75,  2,  5,  6, 16, 17, 59, 12, 14,  8,  7,  4,  0,  0,  0,  0,
        0,  0], dtype=int32)

In [19]:
urutan_input_dict = {}

for image_filename in image_filenames_train:
    captions = []
    image = image_filename.split('.')[0]
    caption = images_captions_dict[image_filename]
    captions.append(caption)
    
    caption_input = tokenizer.texts_to_sequences(captions)

    caption_input = tf.keras.preprocessing.sequence.pad_sequences(caption_input, padding='post', maxlen=max_caption_length)
        
    if image_filename not in urutan_input_dict:
        urutan_input_dict[image] = caption_input

In [20]:
import pickle

output_file = os.path.join('/kaggle/working/OutputCodeTA', 'urutan_input_dict.pkl')
with open(output_file, 'wb') as f:
    pickle.dump(urutan_input_dict, f)

In [21]:
urutan_target_dict = {}

for image_filename in image_filenames_train:
    captions = []
    image = image_filename.split('.')[0]
    caption = images_captions_dict[image_filename]

    caption = remove_start_and_end_tokens(caption)
    captions.append(caption)
    
    caption_target = tokenizer.texts_to_sequences(captions)

    caption_target = tf.keras.preprocessing.sequence.pad_sequences(caption_target, padding='post', maxlen=max_caption_length)
    
    if image_filename not in urutan_target_dict:
        urutan_target_dict[image] = caption_target

In [22]:
import pickle

output_file = os.path.join('/kaggle/working/OutputCodeTA', 'urutan_target_dict.pkl')
with open(output_file, 'wb') as f:
    pickle.dump(urutan_target_dict, f)

In [23]:
urutan_target_dictval = {}

for image_filename in image_filenames_val:
    captions = []
    image = image_filename.split('.')[0]
    caption = images_captions_dict[image_filename]

    caption = remove_start_and_end_tokens(caption)
    captions.append(caption)
    
    caption_target = tokenizer.texts_to_sequences(captions)

    caption_target = tf.keras.preprocessing.sequence.pad_sequences(caption_target, padding='post', maxlen=max_caption_length)
    
    if image_filename not in urutan_target_dictval:
        urutan_target_dictval[image] = caption_target

In [24]:
import pickle

output_file = os.path.join('/kaggle/working/OutputCodeTA', 'urutan_target_dictval.pkl')
with open(output_file, 'wb') as f:
    pickle.dump(urutan_target_dictval, f)

In [25]:
output_file = os.path.join('/kaggle/working/OutputCodeTA', 'image_filenames_train_id.txt')
with open(output_file, "w") as file:
    for image_filename in image_filenames_train:
        image_filename = image_filename.split('.')[0]
        file.write(image_filename + "\n")

In [87]:
output_file = os.path.join('/kaggle/working/OutputCodeTA', 'image_filenames_trainval_id.txt')
with open(output_file, "w") as file:
    i = 0
    for image_filename in image_filenames_train:
        image_filename = image_filename.split('.')[0]
        file.write(image_filename + "\n")
        if i == 100:
            break
        i += 1

In [26]:
output_file = os.path.join('/kaggle/working/OutputCodeTA', 'image_filenames_val_id.txt')
with open(output_file, "w") as file:
    for image_filename in image_filenames_val:
        image_filename = image_filename.split('.')[0]
        file.write(image_filename + "\n")

In [27]:
output_file = os.path.join('/kaggle/working/OutputCodeTA', 'image_filenames_test_id.txt')
with open(output_file, "w") as file:
    for image_filename in image_filenames_test:
        image_filename = image_filename.split('.')[0]
        file.write(image_filename + "\n")

In [88]:
import os
import os.path as osp
import numpy as np
from easydict import EasyDict as edict

__C = edict()

cfg = __C

# ---------------------------------------------------------------------------- #
# Training options
# ---------------------------------------------------------------------------- #
__C.TRAIN = edict()

# Minibatch size
__C.TRAIN.BATCH_SIZE = 10

# scheduled sampling
__C.TRAIN.SCHEDULED_SAMPLING = edict()

__C.TRAIN.SCHEDULED_SAMPLING.START = 6

__C.TRAIN.SCHEDULED_SAMPLING.INC_EVERY = 5

__C.TRAIN.SCHEDULED_SAMPLING.INC_PROB = 0.05

__C.TRAIN.SCHEDULED_SAMPLING.MAX_PROB = 0.5

# ---------------------------------------------------------------------------- #
# Inference ('test') options
# ---------------------------------------------------------------------------- #
__C.TEST = edict()

# Minibatch size
__C.TEST.BATCH_SIZE = 16


# ---------------------------------------------------------------------------- #
# Data loader options
# ---------------------------------------------------------------------------- #
__C.DATA_LOADER = edict()

# Data directory
__C.DATA_LOADER.NUM_WORKERS = 3

__C.DATA_LOADER.PIN_MEMORY = True

__C.DATA_LOADER.DROP_LAST = True

__C.DATA_LOADER.SHUFFLE = True

__C.DATA_LOADER.TRAIN_GV_FEAT = ''

__C.DATA_LOADER.TRAIN_ATT_FEATS = '/kaggle/working/OutputCodeTA/image_feature152'

__C.DATA_LOADER.VAL_GV_FEAT = ''

__C.DATA_LOADER.VAL_ATT_FEATS = '/kaggle/working/OutputCodeTA/image_feature152'

__C.DATA_LOADER.TEST_GV_FEAT = ''

__C.DATA_LOADER.TEST_ATT_FEATS = '/kaggle/working/OutputCodeTA/image_feature152'

__C.DATA_LOADER.TRAIN_ID = '/kaggle/working/OutputCodeTA/image_filenames_train_id.txt'

__C.DATA_LOADER.VALT_ID = '/kaggle/working/OutputCodeTA/image_filenames_trainval_id.txt'

__C.DATA_LOADER.VAL_ID = '/kaggle/working/OutputCodeTA/image_filenames_val_id.txt'

__C.DATA_LOADER.TEST_ID = '/kaggle/working/OutputCodeTA/image_filenames_test_id.txt'

__C.DATA_LOADER.INPUT_SEQ_PATH = '/kaggle/working/OutputCodeTA/urutan_input_dict.pkl'

__C.DATA_LOADER.TARGET_SEQ_PATH = '/kaggle/working/OutputCodeTA/urutan_target_dict.pkl'

__C.DATA_LOADER.TARGETVAL_SEQ_PATH = '/kaggle/working/OutputCodeTA/urutan_target_dictval.pkl'

__C.DATA_LOADER.SEQ_PER_IMG = 1

__C.DATA_LOADER.MAX_FEAT = -1

# ---------------------------------------------------------------------------- #
# Model options
# ---------------------------------------------------------------------------- #
__C.MODEL = edict()

__C.MODEL.TYPE = 'XLAN'            

__C.MODEL.SEQ_LEN = 19             

__C.MODEL.VOCAB_SIZE = 96        

__C.MODEL.WORD_EMBED_DIM = 256 #1024

__C.MODEL.WORD_EMBED_ACT = 'CELU'       

__C.MODEL.WORD_EMBED_NORM = False

__C.MODEL.DROPOUT_WORD_EMBED = 0.5 #0.5

__C.MODEL.GVFEAT_DIM = 2048

__C.MODEL.GVFEAT_EMBED_DIM = -1

__C.MODEL.GVFEAT_EMBED_ACT = 'NONE'     

__C.MODEL.DROPOUT_GV_EMBED = 0.0

__C.MODEL.ATT_FEATS_DIM = 2048

__C.MODEL.ATT_FEATS_EMBED_DIM = 1024

__C.MODEL.ATT_FEATS_EMBED_ACT = 'CELU'   

__C.MODEL.DROPOUT_ATT_EMBED = 0.5

__C.MODEL.ATT_FEATS_NORM = False

__C.MODEL.ATT_HIDDEN_SIZE = -1

__C.MODEL.ATT_HIDDEN_DROP = 0.0

__C.MODEL.ATT_ACT = 'TANH'  

__C.MODEL.RNN_SIZE = 1024

__C.MODEL.DROPOUT_LM = 0.5 


# Bilinear
__C.MODEL.BILINEAR = edict()

__C.MODEL.BILINEAR.DIM = 1024

__C.MODEL.BILINEAR.DECODE_DIM = 1024

__C.MODEL.BILINEAR.ENCODE_ATT_MID_DIM = [128, 64, 128]

__C.MODEL.BILINEAR.DECODE_ATT_MID_DIM = [128, 64, 128]

__C.MODEL.BILINEAR.ENCODE_ATT_MID_DROPOUT = 0.1

__C.MODEL.BILINEAR.DECODE_ATT_MID_DROPOUT = 0.1

__C.MODEL.BILINEAR.ATT_DIM = 1024

__C.MODEL.BILINEAR.ACT = 'CELU'  

__C.MODEL.BILINEAR.ENCODE_DROPOUT = 0.5

__C.MODEL.BILINEAR.DECODE_DROPOUT = 0.5 

__C.MODEL.BILINEAR.ENCODE_LAYERS = 4        # Orde 8, ubah jumlah encode layer jika ingin mengganti orde

__C.MODEL.BILINEAR.DECODE_LAYERS = 1

__C.MODEL.BILINEAR.TYPE = 'LowRank'

__C.MODEL.BILINEAR.ATTTYPE = 'SCAtt'

__C.MODEL.BILINEAR.HEAD = 8

__C.MODEL.BILINEAR.DECODE_HEAD = 8

__C.MODEL.BILINEAR.ENCODE_FF_DROPOUT = 0.1

__C.MODEL.BILINEAR.DECODE_FF_DROPOUT = 0.1

__C.MODEL.BILINEAR.ENCODE_BLOCK = 'LowRankBilinearEnc'

__C.MODEL.BILINEAR.DECODE_BLOCK = 'LowRankBilinearDec'

__C.MODEL.BILINEAR.ELU_ALPHA = 1.3

__C.MODEL.BILINEAR.BIFEAT_EMB_ACT = 'RELU'

__C.MODEL.BILINEAR.ENCODE_BIFEAT_EMB_DROPOUT = 0.3

__C.MODEL.BILINEAR.DECODE_BIFEAT_EMB_DROPOUT = 0.3

# ---------------------------------------------------------------------------- #
# Solver options
# ---------------------------------------------------------------------------- #
__C.SOLVER = edict()

# Base learning rate for the specified schedule
__C.SOLVER.BASE_LR = 0.00004

# Solver type
__C.SOLVER.TYPE = 'ADAM'                 # 'ADAM', 'ADAMAX', 'SGD'

# Maximum number of SGD iterations
__C.SOLVER.MAX_EPOCH = 70

__C.SOLVER.MAX_ITER = -1


# L2 regularization hyperparameter
__C.SOLVER.WEIGHT_DECAY = 0.0001

__C.SOLVER.WEIGHT_DECAY_BIAS = 0.0

__C.SOLVER.BIAS_LR_FACTOR = 1

__C.SOLVER.DISPLAY = 20

__C.SOLVER.TEST_INTERVAL = 1

__C.SOLVER.SNAPSHOT_ITERS = 3

# SGD
__C.SOLVER.SGD = edict()
__C.SOLVER.SGD.MOMENTUM = 0.95

# ADAM
__C.SOLVER.ADAM = edict()
__C.SOLVER.ADAM.BETAS = [0.9, 0.999]
__C.SOLVER.ADAM.EPS = 1e-8

# ---------------------------------------------------------------------------- #
# Losses options
# ---------------------------------------------------------------------------- #
__C.LOSSES = edict()

__C.LOSSES.XE_TYPE = 'CrossEntropy'   


# ---------------------------------------------------------------------------- #
# PARAM options
# ---------------------------------------------------------------------------- #
__C.PARAM = edict()

__C.PARAM.WT = 'WT'

__C.PARAM.GLOBAL_FEAT = 'GV_FEAT'

__C.PARAM.ATT_FEATS = 'ATT_FEATS'

__C.PARAM.ATT_FEATS_MASK = 'ATT_FEATS_MASK'

__C.PARAM.P_ATT_FEATS = 'P_ATT_FEATS'

__C.PARAM.STATE = 'STATE'

__C.PARAM.INPUT_SENT = 'INPUT_SENT'

__C.PARAM.TARGET_SENT = 'TARGET_SENT'

__C.PARAM.INDICES = 'INDICES'

# ---------------------------------------------------------------------------- #
# Inference options
# ---------------------------------------------------------------------------- #
__C.INFERENCE = edict()

__C.INFERENCE.VOCAB = '/kaggle/working/OutputCodeTA/data_vocabulary.txt'

__C.INFERENCE.BEAM_SIZE = 3

__C.INFERENCE.GREEDY_DECODE = True # Greedy decode or sample decode


In [29]:
import torch.nn as nn

def activation(act):
    if act == 'RELU':
        return nn.ReLU()
    elif act == 'TANH':
        return nn.Tanh()
    elif act == 'GLU':
        return nn.GLU()
    elif act == 'ELU':
        return nn.ELU(cfg.MODEL.BILINEAR.ELU_ALPHA)
    elif act == 'CELU':
        return nn.CELU(cfg.MODEL.BILINEAR.ELU_ALPHA)
    else:
        return nn.Identity()

def expand_tensor(tensor, size, dim=1):
    if size == 1 or tensor is None:
        return tensor
    tensor = tensor.unsqueeze(dim)
    tensor = tensor.expand(list(tensor.shape[:dim]) + [size] + list(tensor.shape[dim+1:])).contiguous()
    tensor = tensor.view(list(tensor.shape[:dim-1]) + [-1] + list(tensor.shape[dim+1:]))
    return tensor

def load_ids(path):
    with open(path, 'r') as fid:
        lines = [line.strip() for line in fid]
    return lines

def load_lines(path):
    with open(path, 'r') as fid:
        lines = [line.strip() for line in fid]
    return lines

def load_vocab(path):
    vocab = ['.']
    with open(path, 'r') as fid:
        for line in fid:
            vocab.append(line.strip())
    return vocab

def decode_sequence(vocab, seq):
    N, T = seq.size()
    sents = []
    for n in range(N):
        words = []
        for t in range(T):
            ix = seq[n, t]
            if ix == 0:
                break
            if ix == 4:
                break
            words.append(vocab[ix])
        sent = ' '.join(words)
        sents.append(sent)
    return sents

class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Optimizer(nn.Module):
    def __init__(self, model):
        super(Optimizer, self).__init__()
        self.setup_optimizer(model)

    def setup_optimizer(self, model):
        params = []
        for key, value in model.named_parameters():
            if not value.requires_grad:
                continue
            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY
            if "bias" in key:
                lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 
                weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

        if cfg.SOLVER.TYPE == 'SGD':
            self.optimizer = torch.optim.SGD(
                params, 
                lr = cfg.SOLVER.BASE_LR, 
                momentum = cfg.SOLVER.SGD.MOMENTUM
            )
        elif cfg.SOLVER.TYPE == 'ADAM':
            self.optimizer = torch.optim.Adam(
                params,
                lr = cfg.SOLVER.BASE_LR, 
                betas = cfg.SOLVER.ADAM.BETAS, 
                eps = cfg.SOLVER.ADAM.EPS
            )
        elif cfg.SOLVER.TYPE == 'ADAMAX':
            self.optimizer = torch.optim.Adamax(
                params,
                lr = cfg.SOLVER.BASE_LR, 
                betas = cfg.SOLVER.ADAM.BETAS, 
                eps = cfg.SOLVER.ADAM.EPS
            )
        else:
            raise NotImplementedError
            
        
        return self.optimizer
            
    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.optimizer.step()
    
    def scheduler1(self):
        return torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.8)
    
    def scheduler2(self):
        return torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[4,15,25], gamma=0.1)
    
    def scheduler3(self):
        return torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')

    def get_lr(self):
        lr = []
        for param_group in self.optimizer.param_groups:
            lr.append(param_group['lr'])
        lr = sorted(list(set(lr)))
        return lr

In [81]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from functools import reduce

class BasicAtt(nn.Module):
    def __init__(self, mid_dims, mid_dropout):
        super(BasicAtt, self).__init__()
        sequential = []
        for i in range(1, len(mid_dims) - 1):
            sequential.append(nn.Linear(mid_dims[i - 1], mid_dims[i]))
            sequential.append(nn.ReLU())
            if mid_dropout > 0:
                sequential.append(nn.Dropout(mid_dropout))
        self.attention_basic = nn.Sequential(*sequential) if len(sequential) > 0 else None
        self.attention_last = nn.Linear(mid_dims[-2], mid_dims[-1])

    def forward(self, att_map, att_mask, value1, value2):
        if self.attention_basic is not None:
            att_map = self.attention_basic(att_map)
        attn_weights = self.attention_last(att_map)
        attn_weights = attn_weights.squeeze(-1)
        if att_mask is not None:
            attn_weights = attn_weights.masked_fill(att_mask.unsqueeze(1) == 0, -1e9)
        attn_weights = F.softmax(attn_weights, dim=-1)
        
        attn = torch.matmul(attn_weights.unsqueeze(-2), value2).squeeze(-2)
        return attn
    
class SCAtt(BasicAtt):
    def __init__(self, mid_dims, mid_dropout):
        super(SCAtt, self).__init__(mid_dims, mid_dropout)
        self.attention_last = nn.Linear(mid_dims[-2], 1)
        self.attention_last2 = nn.Linear(mid_dims[-2], mid_dims[-1])

    def forward(self, att_map, att_mask, value1, value2):
        if self.attention_basic is not None:
            att_map = self.attention_basic(att_map)

        if att_mask is not None:
            att_mask = att_mask.unsqueeze(1)
            att_mask_ext = att_mask.unsqueeze(-1)
            att_map_pool = torch.sum(att_map * att_mask_ext, -2) / torch.sum(att_mask_ext, -2)
        else:
            att_map_pool = att_map.mean(-2)

        alpha_spatial = self.attention_last(att_map)
        alpha_channel = self.attention_last2(att_map_pool)
        alpha_channel = torch.sigmoid(alpha_channel)

        alpha_spatial = alpha_spatial.squeeze(-1)
        if att_mask is not None:
            alpha_spatial = alpha_spatial.masked_fill(att_mask == 0, -1e9)
        alpha_spatial = F.softmax(alpha_spatial, dim=-1)

        if len(alpha_spatial.shape) == 4: # batch_size * head_num * seq_num * seq_num (for xtransformer)
            value2 = torch.matmul(alpha_spatial, value2)
        else:
            value2 = torch.matmul(alpha_spatial.unsqueeze(-2), value2).squeeze(-2)

        attn = value1 * value2 * alpha_channel

        return attn
    
class LowRank(nn.Module):
    def __init__(self, embed_dim, att_type, att_heads, att_mid_dim, att_mid_drop):
        super(LowRank, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = att_heads
        self.head_dim = embed_dim // self.num_heads
        self.scaling = self.head_dim ** -0.5
        output_dim = 2 * embed_dim if cfg.MODEL.BILINEAR.ACT == 'GLU' else embed_dim #CELU

        sequential = []
        sequential.append(nn.Linear(embed_dim, output_dim))
        act = activation(cfg.MODEL.BILINEAR.ACT)
        if act is not None:
            sequential.append(act)
        sequential.append(torch.nn.GroupNorm(self.num_heads, embed_dim))
        self.in_proj_q = nn.Sequential(*sequential)

        sequential = []
        sequential.append(nn.Linear(embed_dim, output_dim))
        act = activation(cfg.MODEL.BILINEAR.ACT)
        if act is not None:
            sequential.append(act)
        sequential.append(torch.nn.GroupNorm(self.num_heads, embed_dim))
        self.in_proj_k = nn.Sequential(*sequential)
        
        sequential = []
        sequential.append(nn.Linear(embed_dim, output_dim))
        act = activation(cfg.MODEL.BILINEAR.ACT)
        if act is not None:
            sequential.append(act)
        sequential.append(torch.nn.GroupNorm(self.num_heads, embed_dim))
        self.in_proj_v1 = nn.Sequential(*sequential)

        sequential = []
        sequential.append(nn.Linear(embed_dim, output_dim))
        act = activation(cfg.MODEL.BILINEAR.ACT)
        if act is not None:
            sequential.append(act)
        sequential.append(torch.nn.GroupNorm(self.num_heads, embed_dim))
        self.in_proj_v2 = nn.Sequential(*sequential)

        self.attn_net = create_layer(att_type, att_mid_dim, att_mid_drop)
        self.clear_buffer() 

    def apply_to_states(self, fn):
        self.buffer_keys = fn(self.buffer_keys)
        self.buffer_value2 = fn(self.buffer_value2)

    def init_buffer(self, batch_size):
        self.buffer_keys = torch.zeros((batch_size, self.num_heads, 0, self.head_dim))
        self.buffer_value2 = torch.zeros((batch_size, self.num_heads, 0, self.head_dim))

    def clear_buffer(self):
        self.buffer_keys = None
        self.buffer_value2 = None

    def forward(self, query, key, mask, value1, value2, precompute=False):
        batch_size = query.size()[0]
        q = self.in_proj_q(query)
        v1 = self.in_proj_v1(value1)

        q = q.view(batch_size, self.num_heads, self.head_dim)
        v1 = v1.view(batch_size, self.num_heads, self.head_dim)

        if precompute == False:
            key = key.view(-1, key.size()[-1])
            value2 = value2.view(-1, value2.size()[-1])
            k = self.in_proj_k(key)
            v2 = self.in_proj_v2(value2)
            
            k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            v2 = v2.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            k = key
            v2 = value2

        attn_map = q.unsqueeze(-2) * k
        attn = self.attn_net(attn_map, mask, v1, v2)
        attn = attn.view(batch_size, self.num_heads * self.head_dim)
        return attn

    def forward2(self, query, key, mask, value1, value2, precompute=False):
        batch_size = query.size()[0]
        query = query.view(-1, query.size()[-1])
        value1 = value1.view(-1, value1.size()[-1])
        
        q = self.in_proj_q(query)
        v1 = self.in_proj_v1(value1)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v1 = v1.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if precompute == False:
            key = key.view(-1, key.size()[-1])
            value2 = value2.view(-1, value2.size()[-1])
            k = self.in_proj_k(key)
            v2 = self.in_proj_v2(value2)
            k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            v2 = v2.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

            if self.buffer_keys is not None and self.buffer_value2 is not None:
                self.buffer_keys = torch.cat([self.buffer_keys, k], dim=2)
                self.buffer_value2 = torch.cat([self.buffer_value2, v2], dim=2)
                k = self.buffer_keys
                v2 = self.buffer_value2
        else:
            k = key
            v2 = value2
        
        attn_map = q.unsqueeze(-2) * k.unsqueeze(-3)
        attn = self.attn_net.forward(attn_map, mask, v1, v2).transpose(1, 2).contiguous()
        attn = attn.view(batch_size, -1, self.num_heads * self.head_dim)
        return attn

    def precompute(self, key, value2):
        batch_size = value2.size()[0]
        key = key.view(-1, key.size()[-1])
        value2 = value2.view(-1, value2.size()[-1])
    
        k = self.in_proj_k(key)
        
        v2 = self.in_proj_v2(value2)

        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v2 = v2.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        return k, v2
    
__factory_layer = {
    'LowRank': LowRank,
    'BasicAtt': BasicAtt,
    'SCAtt': SCAtt,
}

def names_layer():
    return sorted(__factory_layer.keys())

def create_layer(name, *args, **kwargs):
    if name not in __factory_layer:
        raise KeyError("Unknown layer:", name)
    return __factory_layer[name](*args, **kwargs)


class LowRankBilinearLayer(nn.Module):
    def __init__(self, embed_dim, att_type, att_heads,
        att_mid_dim, att_mid_drop, dropout):
        super(LowRankBilinearLayer, self).__init__()
        self.encoder_attn = LowRank(
            embed_dim = embed_dim, 
            att_type = att_type, 
            att_heads = att_heads, 
            att_mid_dim = att_mid_dim, 
            att_mid_drop = att_mid_drop
        )
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x, key=None, mask=None, 
        value1=None, value2=None, precompute=False):    
        x = self.encoder_attn(
            query=x,
            key=key if key is not None else x,
            mask=mask,
            value1=value1 if value1 is not None else x,
            value2=value2 if value2 is not None else x,
            precompute=precompute
        )
        if self.dropout is not None:
            x = self.dropout(x)
        return x

    def precompute(self, key, value2):
        return self.encoder_attn.precompute(key, value2)

class LowRankBilinearEncBlock(nn.Module):
    def __init__(self, embed_dim, att_type, att_heads, att_mid_dim,
        att_mid_drop, dropout, layer_num):
        super(LowRankBilinearEncBlock, self).__init__()
        
        self.layers = nn.ModuleList([])
        self.bifeat_emb = nn.ModuleList([])
        self.layer_norms = nn.ModuleList([]) 
        for _ in range(layer_num):
            sublayer = LowRankBilinearLayer( 
                embed_dim = embed_dim, 
                att_type = att_type,
                att_heads = att_heads,
                att_mid_dim = att_mid_dim,
                att_mid_drop = att_mid_drop,
                dropout = dropout
            )
            self.layers.append(sublayer)

            self.bifeat_emb.append(nn.Sequential(
                nn.Linear(2 * embed_dim, embed_dim),
                activation(cfg.MODEL.BILINEAR.BIFEAT_EMB_ACT),
                nn.Dropout(cfg.MODEL.BILINEAR.ENCODE_BIFEAT_EMB_DROPOUT)
            ))

            self.layer_norms.append(torch.nn.LayerNorm(embed_dim))
        
        self.projj = nn.Linear(embed_dim, cfg.MODEL.BILINEAR.DECODE_DIM)
        self.proj = nn.Linear(embed_dim * (layer_num + 1), cfg.MODEL.BILINEAR.DECODE_DIM)
        self.layer_norm = torch.nn.LayerNorm(cfg.MODEL.BILINEAR.DECODE_DIM)
        
    def forward(self, gv_feat, att_feats, att_mask, p_att_feats=None):
        if gv_feat.shape[-1] == 1:  # empty gv_feat
            gv_feat = torch.sum(att_feats * att_mask.unsqueeze(-1), 1) / torch.sum(att_mask.unsqueeze(-1), 1)
        
        feat_arr = [gv_feat]
        for i, layer in enumerate(self.layers):
            gv_feat = layer(gv_feat, att_feats, att_mask, gv_feat, att_feats)
            att_feats_cat = torch.cat([gv_feat.unsqueeze(1).expand_as(att_feats), att_feats], dim = -1)
            att_feats = self.bifeat_emb[i](att_feats_cat) + att_feats
            att_feats = self.layer_norms[i](att_feats)
            feat_arr.append(gv_feat)
            
        att_feats = self.projj(att_feats)
        gv_feat = torch.cat(feat_arr, dim=-1)
        gv_feat = self.proj(gv_feat)
        gv_feat = self.layer_norm(gv_feat)
   
        return gv_feat, att_feats

class LowRankBilinearDecBlock(nn.Module):
    def __init__(self, embed_dim, att_type, att_heads,
        att_mid_dim, att_mid_drop, dropout, layer_num):
        super(LowRankBilinearDecBlock, self).__init__()
        
        self.layers = nn.ModuleList([])
        for _ in range(layer_num):
            sublayer = LowRankBilinearLayer( 
                embed_dim = embed_dim, 
                att_type = att_type,
                att_heads = att_heads,
                att_mid_dim = att_mid_dim,
                att_mid_drop = att_mid_drop,
                dropout = dropout
            )
            self.layers.append(sublayer)
        
        self.proj = nn.Linear(embed_dim * (layer_num + 1), embed_dim)
        self.layer_norm = torch.nn.LayerNorm(cfg.MODEL.BILINEAR.DECODE_DIM)
        
    def precompute(self, key, value2):
        keys = []
        value2s = []
        for layer in self.layers:
            k, v = layer.precompute(key, value2)
            keys.append(k)
            value2s.append(v)

        return torch.cat(keys, dim=-1), torch.cat(value2s, dim=-1)

    def forward(self, gv_feat, att_feats, att_mask, p_att_feats=None, precompute=False):
        if precompute == True:
            dim = p_att_feats.size()[-1]
            keys = p_att_feats.narrow(-1, 0, dim // 2)
            value2s = p_att_feats.narrow(-1, dim // 2, dim // 2)
            dim = keys.size()[-1] // len(self.layers)
    
        if gv_feat.shape[-1] == 1:  # empty gv_feat
            if att_mask is not None:
                gv_feat = (torch.sum(att_feats * att_mask.unsqueeze(-1), 1) / torch.sum(att_mask.unsqueeze(-1), 1))
            else:
                gv_feat = torch.mean(att_feats, 1)

        feat_arr = [gv_feat]
        for i, layer in enumerate(self.layers):
            key = keys.narrow(-1, i * dim, dim) if precompute else att_feats
            value2 = value2s.narrow(-1, i * dim, dim) if precompute else att_feats
            gv_feat = layer(gv_feat, key, att_mask, gv_feat, value2, precompute)
            feat_arr.append(gv_feat)

        gv_feat = torch.cat(feat_arr, dim=-1)
        gv_feat = self.proj(gv_feat)
        gv_feat = self.layer_norm(gv_feat)
    
        return gv_feat, att_feats
    
class FeedForwardBlock(nn.Module):
    def __init__(self, embed_dim, ffn_embed_dim, 
        relu_dropout, dropout):
        super(FeedForwardBlock, self).__init__()

        self.fc1 = nn.Linear(embed_dim, ffn_embed_dim)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim)
        self.dropout = dropout
        self.relu_dropout = relu_dropout
        self.layer_norms = torch.nn.LayerNorm(embed_dim)

    def forward(self, x):
        residual = x
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.layer_norms(x)
        return x
    
__factory_block = {
    'FeedForward': FeedForwardBlock,
    'LowRankBilinearEnc': LowRankBilinearEncBlock,
    'LowRankBilinearDec': LowRankBilinearDecBlock,
}

def names_block():
    return sorted(__factory_block.keys())

def create_block(name, *args, **kwargs):
    if name not in __factory_block:
        raise KeyError("Unknown blocks:", name)
    return __factory_block[name](*args, **kwargs)


class BasicModel(nn.Module):
    def __init__(self):
        super(BasicModel, self).__init__()

    def select(self, batch_size, beam_size, t, candidate_logprob):
        selected_logprob, selected_idx = torch.sort(candidate_logprob.view(batch_size, -1), -1, descending=True)
        selected_logprob, selected_idx = selected_logprob[:, :beam_size], selected_idx[:, :beam_size]
        return selected_idx, selected_logprob

    def beam_search(self, init_state, init_logprobs, **kwargs):
        all_logprobs = []
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
        
            ys,ix = torch.sort(logprobsf,1,True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols): # for each column (word, essentially)
                for q in range(rows): # for each beam expansion
                    #compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q,c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
                    candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob})
            candidates = sorted(candidates,  key=lambda x: -x['p'])
            
            new_state = [_.clone() for _ in state]
            if t >= 1:
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
                for state_ix in range(len(new_state)):
                    new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
                beam_seq[t, vix] = v['c'] # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
                beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
            state = new_state
            return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates

        beam_size = kwargs['BEAM_SIZE']
        group_size = 1 #kwargs['GROUP_SIZE']
        diversity_lambda = 0.5 #kwargs['DIVERSITY_LAMBDA']
        constraint = False #kwargs['CONSTRAINT']
        max_ppl = False #kwargs['MAX_PPL']
        bdash = beam_size // group_size
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        beam_seq_table = [torch.LongTensor(cfg.MODEL.SEQ_LEN, bdash).zero_() for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(cfg.MODEL.SEQ_LEN, bdash).zero_() for _ in range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]

        done_beams_table = [[] for _ in range(group_size)]
        state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
        logprobs_table = list(init_logprobs.chunk(group_size, 0))

        for t in range(cfg.MODEL.SEQ_LEN + group_size - 1):
            for divm in range(group_size): 
                if t >= divm and t <= cfg.MODEL.SEQ_LEN + divm - 1:
                    logprobsf = logprobs_table[divm].data.float()
                    if constraint and t-divm > 0:
                        logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(device), float('-inf'))
                    logprobsf[:,logprobsf.size(1)-1] -= 1000  

                    unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)

                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm],\
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    all_logprobs.append(logprobsf)
                    for vix in range(bdash):
                        if beam_seq_table[divm][t-divm,vix] == 0 or t == cfg.MODEL.SEQ_LEN + divm - 1:
                            final_beam = {
                                'seq': beam_seq_table[divm][:, vix].clone(), 
                                'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                                'p': beam_logprobs_sum_table[divm][vix].item()
                            }
                            if max_ppl:
                                final_beam['p'] = final_beam['p'] / (t-divm+1)
                            done_beams_table[divm].append(final_beam)
                            beam_logprobs_sum_table[divm][vix] = -1000

                    wt = beam_seq_table[divm][t-divm]
                    kwargs[cfg.PARAM.WT] = wt.to(device) #cuda()
                    kwargs[cfg.PARAM.STATE] = state_table[divm]
                    logprobs_table[divm], state_table[divm], _ = self.get_logprobs_state(**kwargs)

        done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
        done_beams = reduce(lambda a,b:a+b, done_beams_table)
        return done_beams, all_logprobs


class AttBasicModel(BasicModel):
    def __init__(self):
        super(AttBasicModel, self).__init__()
        self.ss_prob = 0.0                               # Schedule sampling probability
        self.vocab_size = cfg.MODEL.VOCAB_SIZE + 1      
        self.att_dim = cfg.MODEL.ATT_FEATS_EMBED_DIM \
            if cfg.MODEL.ATT_FEATS_EMBED_DIM > 0 else cfg.MODEL.ATT_FEATS_DIM 

        # word embed
        sequential = [nn.Embedding(self.vocab_size, cfg.MODEL.WORD_EMBED_DIM)]
        sequential.append(activation(cfg.MODEL.WORD_EMBED_ACT))
        if cfg.MODEL.WORD_EMBED_NORM == True: 
            sequential.append(nn.LayerNorm(cfg.MODEL.WORD_EMBED_DIM))
        if cfg.MODEL.DROPOUT_WORD_EMBED > 0:
            sequential.append(nn.Dropout(cfg.MODEL.DROPOUT_WORD_EMBED))
        self.word_embed = nn.Sequential(*sequential)

        # global visual feat embed
        sequential = []
        if cfg.MODEL.GVFEAT_EMBED_DIM > 0: 
            sequential.append(nn.Linear(cfg.MODEL.GVFEAT_DIM, cfg.MODEL.GVFEAT_EMBED_DIM))
        sequential.append(activation(cfg.MODEL.GVFEAT_EMBED_ACT))
        if cfg.MODEL.DROPOUT_GV_EMBED > 0: 
            sequential.append(nn.Dropout(cfg.MODEL.DROPOUT_GV_EMBED))
        self.gv_feat_embed = nn.Sequential(*sequential) if len(sequential) > 0 else None

        # attention feats embed
        sequential = []
        if cfg.MODEL.ATT_FEATS_EMBED_DIM > 0:
            sequential.append(nn.Linear(cfg.MODEL.ATT_FEATS_DIM, cfg.MODEL.ATT_FEATS_EMBED_DIM))
        sequential.append(activation(cfg.MODEL.ATT_FEATS_EMBED_ACT))
        if cfg.MODEL.DROPOUT_ATT_EMBED > 0:
            sequential.append(nn.Dropout(cfg.MODEL.DROPOUT_ATT_EMBED))
        if cfg.MODEL.ATT_FEATS_NORM == True: 
            sequential.append(torch.nn.LayerNorm(cfg.MODEL.ATT_FEATS_EMBED_DIM))
        self.att_embed = nn.Sequential(*sequential) if len(sequential) > 0 else None

        self.dropout_lm  = nn.Dropout(cfg.MODEL.DROPOUT_LM) if cfg.MODEL.DROPOUT_LM > 0 else None
        self.logit = nn.Linear(cfg.MODEL.RNN_SIZE, self.vocab_size)
        self.p_att_feats = nn.Linear(self.att_dim, cfg.MODEL.ATT_HIDDEN_SIZE) \
            if cfg.MODEL.ATT_HIDDEN_SIZE > 0 else None 

        # bilinear
        if cfg.MODEL.BILINEAR.DIM > 0:
            self.p_att_feats = None
            self.encoder_layers = create_block(
                cfg.MODEL.BILINEAR.ENCODE_BLOCK,
                embed_dim = cfg.MODEL.BILINEAR.DIM, 
                att_type = cfg.MODEL.BILINEAR.ATTTYPE,
                att_heads = cfg.MODEL.BILINEAR.HEAD,
                att_mid_dim = cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DIM,
                att_mid_drop = cfg.MODEL.BILINEAR.ENCODE_ATT_MID_DROPOUT,
                dropout = cfg.MODEL.BILINEAR.ENCODE_DROPOUT, 
                layer_num = cfg.MODEL.BILINEAR.ENCODE_LAYERS
            )

    def init_hidden(self, batch_size):
        return [Variable(torch.zeros(self.num_layers, batch_size, cfg.MODEL.RNN_SIZE).cuda()),
                Variable(torch.zeros(self.num_layers, batch_size, cfg.MODEL.RNN_SIZE).cuda())]

    def make_kwargs(self, wt, gv_feat, att_feats, att_mask, p_att_feats, state, **kgs):
        kwargs = kgs
        kwargs[cfg.PARAM.WT] = wt
        kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
        kwargs[cfg.PARAM.ATT_FEATS] = att_feats
        kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
        kwargs[cfg.PARAM.P_ATT_FEATS] = p_att_feats
        kwargs[cfg.PARAM.STATE] = state
        return kwargs

    def preprocess(self, **kwargs):
        gv_feat = kwargs[cfg.PARAM.GLOBAL_FEAT]
        att_feats = kwargs[cfg.PARAM.ATT_FEATS]
        att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]

        # embed gv_feat
        if self.gv_feat_embed is not None:
            gv_feat = self.gv_feat_embed(gv_feat)
        
        # embed att_feats
        if self.att_embed is not None:    
            att_feats = self.att_embed(att_feats)

        p_att_feats = self.p_att_feats(att_feats) if self.p_att_feats is not None else None

        # bilinear
        if cfg.MODEL.BILINEAR.DIM > 0:
            gv_feat, att_feats = self.encoder_layers(gv_feat, att_feats, att_mask)
            keys, value2s = self.attention.precompute(att_feats, att_feats)
            p_att_feats = torch.cat([keys, value2s], dim=-1)

        return gv_feat, att_feats, att_mask, p_att_feats

    def forward(self, **kwargs): 
        seq = kwargs[cfg.PARAM.INPUT_SENT]
        gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)
        gv_feat = expand_tensor(gv_feat, cfg.DATA_LOADER.SEQ_PER_IMG)
        att_feats = expand_tensor(att_feats, cfg.DATA_LOADER.SEQ_PER_IMG)
        att_mask = expand_tensor(att_mask, cfg.DATA_LOADER.SEQ_PER_IMG)
        p_att_feats = expand_tensor(p_att_feats, cfg.DATA_LOADER.SEQ_PER_IMG)

        batch_size = gv_feat.size(0)
        state = self.init_hidden(batch_size)

        outputs = Variable(torch.zeros(batch_size, seq.size(1), self.vocab_size).cuda())
        for t in range(seq.size(1)):
            if self.training and t >=1 and self.ss_prob > 0:
                prob = torch.empty(batch_size).cuda().uniform_(0, 1)
                mask = prob < self.ss_prob
                if mask.sum() == 0:
                    wt = seq[:,t].clone()
                else:
                    ind = mask.nonzero().view(-1)
                    wt = seq[:, t].data.clone()
                    prob_prev = torch.exp(outputs[:, t-1].detach())
                    wt = wt.long()
                    wt.index_copy_(0, ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, ind))
            else:
                wt = seq[:,t].clone()

            if t >= 1 and seq[:, t].max() == 0:
                break
            
            kwargs = self.make_kwargs(wt, gv_feat, att_feats, att_mask, p_att_feats, state)
            output, state = self.Forward(**kwargs)
            if self.dropout_lm is not None:
                output = self.dropout_lm(output)

            logit = self.logit(output)
            outputs[:, t] = logit

        return outputs

    def get_logprobs_state(self, **kwargs):
        output, state = self.Forward(**kwargs)
        logprobs = F.log_softmax(self.logit(output), dim=1)
        outputs = self.logit(output)
        return logprobs, state, outputs
    
    def _expand_state(self, batch_size, beam_size, cur_beam_size, state, selected_beam):
        shape = [int(sh) for sh in state.shape]
        beam = selected_beam
        for _ in shape[2:]:
            beam = beam.unsqueeze(-1)
        beam = beam.unsqueeze(0)
        beam_long = beam.long()
        state = torch.gather(
            state.view(*([shape[0], batch_size, cur_beam_size] + shape[2:])), 2,
            beam_long.expand(*([shape[0], batch_size, beam_size] + shape[2:]))
        )
        state = state.view(*([shape[0], -1, ] + shape[2:]))
        return state

    def decode_beam(self, **kwargs):
        beam_size = kwargs['BEAM_SIZE']
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)
        batch_size = gv_feat.size(0)
    
        sents = Variable(torch.zeros((cfg.MODEL.SEQ_LEN, batch_size), dtype=torch.long).to(device))#cuda())
        logprobs = Variable(torch.zeros(cfg.MODEL.SEQ_LEN, batch_size).to(device))#cuda())   
        self.done_beams = [[] for _ in range(batch_size)]
        all_logprobs_batch = []
        for n in range(batch_size):
            state = self.init_hidden(beam_size)
            gv_feat_beam = gv_feat[n:n+1].expand(beam_size, gv_feat.size(1)).contiguous()
            att_feats_beam = att_feats[n:n+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous()
            att_mask_beam = att_mask[n:n+1].expand(*((beam_size,)+att_mask.size()[1:]))
            p_att_feats_beam = p_att_feats[n:n+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() if p_att_feats is not None else None
    
            wt = Variable(torch.ones(beam_size, dtype=torch.long).to(device)) * 3
            kwargs = self.make_kwargs(wt, gv_feat_beam, att_feats_beam, att_mask_beam, p_att_feats_beam, state, **kwargs)
            logprobs_t, state, _ = self.get_logprobs_state(**kwargs)
    
            self.done_beams[n], all_logprobs = self.beam_search(state, logprobs_t, **kwargs)
            all_logprobs = torch.stack(all_logprobs)
            all_logprobs_batch.append(all_logprobs)
            all_logprobs_fix = torch.cat(all_logprobs_batch, dim=0)
            sents[:, n] = self.done_beams[n][0]['seq'] 
            logprobs[:, n] = self.done_beams[n][0]['logps']

        return sents.transpose(0, 1), logprobs.transpose(0, 1), all_logprobs_fix

    def decode(self, **kwargs):
        greedy_decode = kwargs['GREEDY_DECODE']
 
        gv_feat, att_feats, att_mask, p_att_feats = self.preprocess(**kwargs)
        batch_size = gv_feat.size(0)
        state = self.init_hidden(batch_size)

        sents = Variable(torch.zeros((batch_size, cfg.MODEL.SEQ_LEN), dtype=torch.long).cuda())
        logprobs = Variable(torch.zeros(batch_size, cfg.MODEL.SEQ_LEN).cuda())
        outputs = Variable(torch.zeros(batch_size, cfg.MODEL.SEQ_LEN, self.vocab_size).cuda())   
        wt = Variable(torch.ones(batch_size, dtype=torch.long).cuda()) * 3
        unfinished = wt.eq(wt)
        all_logprobs = []
        
        for t in range(cfg.MODEL.SEQ_LEN):
            kwargs = self.make_kwargs(wt, gv_feat, att_feats, att_mask, p_att_feats, state)
            logprobs_t, state, output = self.get_logprobs_state(**kwargs)
            all_logprobs.append(logprobs_t)
            
            if greedy_decode:
                logP_t, wt = torch.max(logprobs_t, 1)
            else:
                probs_t = torch.exp(logprobs_t)
                wt = torch.multinomial(probs_t, 1)
                logP_t = logprobs_t.gather(1, wt)

            wt = wt.view(-1).long()
            unfinished = unfinished * ((wt != 0) & (wt != 4))
            wt = wt * unfinished.type_as(wt)
            sents[:,t] = wt
            logprobs[:,t] = logP_t.view(-1)
            outputs[:,t] = output

            if unfinished.sum() == 0:
                break
  
        return sents, logprobs, all_logprobs, outputs


class XLAN(AttBasicModel):
    def __init__(self):
        super(XLAN, self).__init__()
        self.num_layers = 2

        rnn_input_size = cfg.MODEL.RNN_SIZE + cfg.MODEL.BILINEAR.DECODE_DIM
        self.att_lstm = nn.LSTMCell(rnn_input_size, cfg.MODEL.RNN_SIZE)
        self.ctx_drop = nn.Dropout(cfg.MODEL.DROPOUT_LM)

        self.attention = create_block(            
            cfg.MODEL.BILINEAR.DECODE_BLOCK, 
            embed_dim = cfg.MODEL.BILINEAR.DECODE_DIM, 
            att_type = cfg.MODEL.BILINEAR.ATTTYPE,
            att_heads = cfg.MODEL.BILINEAR.DECODE_HEAD,
            att_mid_dim = cfg.MODEL.BILINEAR.DECODE_ATT_MID_DIM,
            att_mid_drop = cfg.MODEL.BILINEAR.DECODE_ATT_MID_DROPOUT,
            dropout = cfg.MODEL.BILINEAR.DECODE_DROPOUT, 
            layer_num = cfg.MODEL.BILINEAR.DECODE_LAYERS
        )
        self.att2ctx = nn.Sequential(
            nn.Linear(cfg.MODEL.BILINEAR.DECODE_DIM + cfg.MODEL.RNN_SIZE, 2 * cfg.MODEL.RNN_SIZE), 
            nn.GLU()
        )

    def Forward(self, **kwargs):

        wt = kwargs[cfg.PARAM.WT]
        att_feats = kwargs[cfg.PARAM.ATT_FEATS]
        att_mask = kwargs[cfg.PARAM.ATT_FEATS_MASK]
        state = kwargs[cfg.PARAM.STATE]
        gv_feat = kwargs[cfg.PARAM.GLOBAL_FEAT]
        p_att_feats = kwargs[cfg.PARAM.P_ATT_FEATS]
        
        if gv_feat.shape[-1] == 1:  # empty gv_feat
            if att_mask is not None:
                gv_feat = torch.sum(att_feats * att_mask.unsqueeze(-1), 1) / torch.sum(att_mask.unsqueeze(-1), 1)
            else:
                gv_feat = torch.mean(att_feats, 1)

        xt = self.word_embed(wt)
        
        h_att, c_att = self.att_lstm(torch.cat([xt, gv_feat + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))

        att, _ = self.attention(h_att, att_feats, att_mask, p_att_feats, precompute=True)
        ctx_input = torch.cat([att, h_att], 1)

        output = self.att2ctx(ctx_input)
        state = [torch.stack((h_att, output)), torch.stack((c_att, state[1][1]))]
        
        return output, state
    
__factory_model = {
    'XLAN': XLAN,
}

def names_model():
    return sorted(__factory_model.keys())

def create_model(name, *args, **kwargs):
    if name not in __factory_model:
        raise KeyError("Unknown caption model:", name)
    return __factory_model[name](*args, **kwargs)



In [34]:
import os
import random
import numpy as np
import torch
import torch.utils.data as data
import pickle

class CocoDataset(data.Dataset):
    def __init__(
        self, 
        image_ids_path, 
        input_seq, 
        target_seq,
        gv_feat_path, 
        att_feats_folder, 
        seq_per_img,
        max_feat_num
    ):
        self.max_feat_num = max_feat_num
        self.seq_per_img = seq_per_img
        self.image_ids = load_lines(image_ids_path)
        self.att_feats_folder = att_feats_folder if len(att_feats_folder) > 0 else None
        self.gv_feat = pickle.load(open(gv_feat_path, 'rb'), encoding='bytes') if len(gv_feat_path) > 0 else None

        if input_seq is not None:
            self.input_seq = pickle.load(open(input_seq, 'rb'), encoding='bytes')
            self.target_seq = pickle.load(open(target_seq, 'rb'), encoding='bytes')
            self.seq_len = len(self.input_seq[self.image_ids[0]][0,:])
        elif target_seq is not None:
            self.target_seq = pickle.load(open(target_seq, 'rb'), encoding='bytes')
            self.input_seq = None
            self.seq_len = -1
        else:
            self.seq_len = -1
            self.input_seq = None
            self.target_seq = None
         
    def set_seq_per_img(self, seq_per_img):
        self.seq_per_img = seq_per_img

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        indices = np.array([index]).astype('int')

        if self.gv_feat is not None:
            gv_feat = self.gv_feat[image_id]
            gv_feat = np.array(gv_feat).astype('float32')
        else:
            gv_feat = np.zeros((1,1))

        if self.att_feats_folder is not None:
            att_feats = np.load(os.path.join(self.att_feats_folder, str(image_id) + '.npz'))['feat']
            att_feats = np.array(att_feats).astype('float32')
        else:
            att_feats = np.zeros((1,1))
        
        if self.max_feat_num > 0 and att_feats.shape[0] > self.max_feat_num:
           att_feats = att_feats[:self.max_feat_num, :]

        if self.seq_len < 0:
            if self.target_seq is not None:
                target_seq = self.target_seq[image_id][0, :].reshape(1, -1)
                return indices, target_seq, gv_feat, att_feats
            else:
                return indices, gv_feat, att_feats

        input_seq = np.zeros((self.seq_per_img, self.seq_len), dtype='int')
        target_seq = np.zeros((self.seq_per_img, self.seq_len), dtype='int')
           
        n = len(self.input_seq[image_id])   
        if n >= self.seq_per_img:
            sid = 0
            ixs = random.sample(range(n), self.seq_per_img)                
        else:
            sid = n
            ixs = random.sample(range(n), self.seq_per_img - n)
            input_seq[0:n, :] = self.input_seq[image_id]
            target_seq[0:n, :] = self.target_seq[image_id]
           
        for i, ix in enumerate(ixs):
            input_seq[sid + i] = self.input_seq[image_id][ix,:]
            target_seq[sid + i] = self.target_seq[image_id][ix,:]
        return indices, input_seq, target_seq, gv_feat, att_feats
    
class CocoVal(data.Dataset):
    def __init__(
        self, 
        image_ids_path, 
        input_seq, 
        target_seq,
        gv_feat_path, 
        att_feats_folder, 
        seq_per_img,
        max_feat_num
    ):
        self.max_feat_num = max_feat_num
        self.seq_per_img = seq_per_img
        self.image_ids = load_lines(image_ids_path)
        self.att_feats_folder = att_feats_folder if len(att_feats_folder) > 0 else None
        self.gv_feat = pickle.load(open(gv_feat_path, 'rb'), encoding='bytes') if len(gv_feat_path) > 0 else None

        if input_seq is not None:
            self.input_seq = pickle.load(open(input_seq, 'rb'), encoding='bytes')
            self.target_seq = pickle.load(open(target_seq, 'rb'), encoding='bytes')
            self.seq_len = len(self.input_seq[self.image_ids[0]][0,:])
        elif target_seq is not None:
            self.target_seq = pickle.load(open(target_seq, 'rb'), encoding='bytes')
            self.input_seq = None
            self.seq_len = -1
        else:
            self.seq_len = -1
            self.input_seq = None
            self.target_seq = None
         
    def set_seq_per_img(self, seq_per_img):
        self.seq_per_img = seq_per_img

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        indices = np.array([index]).astype('int')

        if self.gv_feat is not None:
            gv_feat = self.gv_feat[image_id]
            gv_feat = np.array(gv_feat).astype('float32')
        else:
            gv_feat = np.zeros((1,1))

        if self.att_feats_folder is not None:
            att_feats = np.load(os.path.join(self.att_feats_folder, str(image_id) + '.npz'))['feat']
            att_feats = np.array(att_feats).astype('float32')
        else:
            att_feats = np.zeros((1,1))
        
        if self.max_feat_num > 0 and att_feats.shape[0] > self.max_feat_num:
           att_feats = att_feats[:self.max_feat_num, :]

        if self.seq_len < 0:
            if self.target_seq is not None:
                target_seq = self.target_seq[image_id][0, :].reshape(1, -1)
                return indices, target_seq, gv_feat, att_feats
            else:
                return indices, gv_feat, att_feats

        input_seq = np.zeros((self.seq_per_img, self.seq_len), dtype='int')
        target_seq = np.zeros((self.seq_per_img, self.seq_len), dtype='int')
           
        n = len(self.input_seq[image_id])   
        if n >= self.seq_per_img:
            sid = 0
            ixs = random.sample(range(n), self.seq_per_img)                
        else:
            sid = n
            ixs = random.sample(range(n), self.seq_per_img - n)
            input_seq[0:n, :] = self.input_seq[image_id]
            target_seq[0:n, :] = self.target_seq[image_id]
           
        for i, ix in enumerate(ixs):
            input_seq[sid + i] = self.input_seq[image_id][ix,:]
            target_seq[sid + i] = self.target_seq[image_id][ix,:]
        return indices, target_seq, gv_feat, att_feats

In [71]:
import numpy as np
import torch
import torch.nn as nn

def sample_collate(batch):
    batch = [(np.array(idx), np.array(inp), np.array(tgt), np.array(gv), np.array(att)) for idx, inp, tgt, gv, att in batch]
    indices, input_seq, target_seq, gv_feat, att_feats = zip(*batch)
    
    indices = np.stack(indices, axis=0).reshape(-1)
    input_seq = torch.cat([torch.from_numpy(b) for b in input_seq], 0)
    target_seq = torch.cat([torch.from_numpy(b) for b in target_seq], 0)
    gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0)
    
    max_att_num = 1
    
    feat_arr = []
    mask_arr = []
    
    for i in range(len(att_feats)):
        tmp_feat = np.zeros((1, max_att_num, 2048), dtype=np.float32)
        tmp_feat[0, 0, :] = att_feats[i]
        feat_arr.append(torch.from_numpy(tmp_feat))
        
        tmp_mask = np.zeros((1, max_att_num), dtype=np.float32)
        tmp_mask[0, 0] = 1
        mask_arr.append(torch.from_numpy(tmp_mask))
    
    att_feats = torch.cat(feat_arr, 0)
    att_mask = torch.cat(mask_arr, 0)
    
    return indices, input_seq, target_seq, gv_feat, att_feats, att_mask

def sample_collate_val(batch):
    batch = [(np.array(idx), np.array(tgt), np.array(gv), np.array(att)) for idx, tgt, gv, att in batch]
    indices, valtarget_seq, gv_feat, att_feats = zip(*batch)
    
    indices = np.stack(indices, axis=0).reshape(-1)
    gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0)
    valtarget_seq = torch.cat([torch.from_numpy(b) for b in valtarget_seq], 0)

    max_att_num = 1

    feat_arr = []
    mask_arr = []
    
    for i in range(len(att_feats)):
        tmp_feat = np.zeros((1, max_att_num, 2048), dtype=np.float32)
        tmp_feat[0, 0, :] = att_feats[i]
        feat_arr.append(torch.from_numpy(tmp_feat))
        
        tmp_mask = np.zeros((1, max_att_num), dtype=np.float32)
        tmp_mask[0, 0] = 1
        mask_arr.append(torch.from_numpy(tmp_mask))

    att_feats = torch.cat(feat_arr, 0)
    att_mask = torch.cat(mask_arr, 0)

    return indices, valtarget_seq, gv_feat, att_feats, att_mask

def sample_collate_test(batch):
    batch = [(np.array(idx), np.array(gv), np.array(att)) for idx, gv, att in batch]
    indices, gv_feat, att_feats = zip(*batch)
    
    indices = np.stack(indices, axis=0).reshape(-1)
    gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0)

    max_att_num = 1

    feat_arr = []
    mask_arr = []
    
    for i in range(len(att_feats)):
        tmp_feat = np.zeros((1, max_att_num, 2048), dtype=np.float32)
        tmp_feat[0, 0, :] = att_feats[i]
        feat_arr.append(torch.from_numpy(tmp_feat))
  
        tmp_mask = np.zeros((1, max_att_num), dtype=np.float32)
        tmp_mask[0, 0] = 1
        mask_arr.append(torch.from_numpy(tmp_mask))

    att_feats = torch.cat(feat_arr, 0)
    att_mask = torch.cat(mask_arr, 0)

    return indices, gv_feat, att_feats, att_mask

def load_train(coco_set):
    loader = torch.utils.data.DataLoader(
        coco_set, 
        batch_size = cfg.TRAIN.BATCH_SIZE,
        shuffle = True, 
        num_workers = 0, #cfg.DATA_LOADER.NUM_WORKERS, 
        drop_last = cfg.DATA_LOADER.DROP_LAST, 
        pin_memory = cfg.DATA_LOADER.PIN_MEMORY,
        sampler = None, 
        collate_fn = sample_collate
    )
    return loader

def load_test(image_ids_path, gv_feat_path, att_feats_folder):
    coco_set = CocoDataset(
        image_ids_path = image_ids_path, 
        input_seq = None, 
        target_seq = None, 
        gv_feat_path = gv_feat_path, 
        att_feats_folder = att_feats_folder,
        seq_per_img = 1, 
        max_feat_num = cfg.DATA_LOADER.MAX_FEAT
    )
    
    loader = torch.utils.data.DataLoader(
        coco_set, 
        batch_size = 1,
        shuffle = False, 
        num_workers = 0, #cfg.DATA_LOADER.NUM_WORKERS, 
        drop_last = False, 
        pin_memory = cfg.DATA_LOADER.PIN_MEMORY, 
        collate_fn = sample_collate_test
    )
    return loader

def load_val(image_ids_path, valtarget_seq, gv_feat_path, att_feats_folder):
    coco_set = CocoVal(
        image_ids_path = image_ids_path, 
        input_seq = None, 
        target_seq = valtarget_seq, 
        gv_feat_path = gv_feat_path, 
        att_feats_folder = att_feats_folder,
        seq_per_img = 1, 
        max_feat_num = cfg.DATA_LOADER.MAX_FEAT
    )
    coco_set = CocoDataset(
        image_ids_path = image_ids_path, 
        input_seq = None, 
        target_seq = valtarget_seq, 
        gv_feat_path = gv_feat_path, 
        att_feats_folder = att_feats_folder,
        seq_per_img = 1, 
        max_feat_num = cfg.DATA_LOADER.MAX_FEAT
    )
    
    loader = torch.utils.data.DataLoader(
        coco_set, 
        batch_size = 1,
        shuffle = False, 
        num_workers = 0, 
        drop_last = False, 
        pin_memory = cfg.DATA_LOADER.PIN_MEMORY, 
        collate_fn = sample_collate_val
    )
    return loader

def setup_loader(coco_set):
    training_loader = load_train(coco_set)
    return training_loader
        
def setup_dataset():
    coco_set = CocoDataset(            
            image_ids_path = cfg.DATA_LOADER.VALT_ID, 
            input_seq = cfg.DATA_LOADER.INPUT_SEQ_PATH, 
            target_seq = cfg.DATA_LOADER.TARGET_SEQ_PATH,
            gv_feat_path = cfg.DATA_LOADER.TRAIN_GV_FEAT, 
            att_feats_folder = cfg.DATA_LOADER.TRAIN_ATT_FEATS, 
            seq_per_img = cfg.DATA_LOADER.SEQ_PER_IMG,
            max_feat_num = cfg.DATA_LOADER.MAX_FEAT
        )
    return coco_set

class CrossEntropy(nn.Module):
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

    def forward(self, logit, target_seq):
        logit = logit.view(-1, logit.shape[-1])
        target_seq = target_seq.view(-1)
        loss = self.criterion(logit, target_seq)
        return loss, {'CrossEntropy Loss': loss.item()}


In [36]:
import os
import sys
import numpy as np
import torch
import tqdm
import json

class EvalTest(object):
    def __init__(
        self,
        eval_ids,
        gv_feat,
        att_feats
    ):
        super(EvalTest, self).__init__()
        self.vocab = load_vocab(cfg.INFERENCE.VOCAB)

        self.eval_ids = np.array(load_ids(eval_ids))
        self.eval_loader = load_test(eval_ids, gv_feat, att_feats)

    def make_kwargs(self, indices, ids, gv_feat, att_feats, att_mask):
        kwargs = {}
        kwargs[cfg.PARAM.INDICES] = indices
        kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
        kwargs[cfg.PARAM.ATT_FEATS] = att_feats
        kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
        kwargs['BEAM_SIZE'] = cfg.INFERENCE.BEAM_SIZE
        kwargs['GREEDY_DECODE'] = cfg.INFERENCE.GREEDY_DECODE
        return kwargs
        
    def __call__(self, model, device):
        model.eval()

        results = []
        with torch.no_grad():
            for batch_idx, (indices, gv_feat, att_feats, att_mask) in tqdm.tqdm(enumerate(self.eval_loader)):
                ids = self.eval_ids[indices]
                gv_feat = gv_feat.to(device)
                att_feats = att_feats.to(device)
                att_mask = att_mask.to(device)

                kwargs = self.make_kwargs(indices, ids, gv_feat, att_feats, att_mask)
                if kwargs['BEAM_SIZE'] > 1:
                    seq, logprob, all_logprobs = model.module.decode_beam(**kwargs)
                else:
                    seq, logprob, all_logprobs, output = model.module.decode(**kwargs)
                    
                sents = decode_sequence(self.vocab, seq.data)
    
                for sid, sent in enumerate(sents):
                    result = {cfg.INFERENCE.ID_KEY: ids[sid], cfg.INFERENCE.CAP_KEY: sent}
                    results.append(result)
        
        model.train()
        return results

In [37]:
import os
import sys
import numpy as np
import torch
import tqdm
import json

class Evaler(object):
    def __init__(
        self,
        eval_ids,
        valtarget_seq,
        gv_feat,
        att_feats
    ):
        super(Evaler, self).__init__()
        self.vocab = load_vocab(cfg.INFERENCE.VOCAB)

        self.eval_ids = np.array(load_ids(eval_ids))
        self.eval_loader = load_val(eval_ids, valtarget_seq, gv_feat, att_feats)
        

    def make_kwargs(self, indices, ids, valtarget_seq, gv_feat, att_feats, att_mask):
        kwargs = {}
        kwargs[cfg.PARAM.INDICES] = indices
        kwargs[cfg.PARAM.TARGET_SENT] = valtarget_seq
        kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
        kwargs[cfg.PARAM.ATT_FEATS] = att_feats
        kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
        kwargs['BEAM_SIZE'] = 1 #cfg.INFERENCE.BEAM_SIZE
        kwargs['GREEDY_DECODE'] = cfg.INFERENCE.GREEDY_DECODE
        return kwargs
        
    def __call__(self, model, device):
        model.eval()
        total_loss = 0.0
        num_batches = 0

        results = []
        with torch.no_grad():
            for batch_idx, (indices, valtarget_seq, gv_feat, att_feats, att_mask) in tqdm.tqdm(enumerate(self.eval_loader)):
                ids = self.eval_ids[indices]
                gv_feat = gv_feat.to(device)
                att_feats = att_feats.to(device)
                att_mask = att_mask.to(device)
                valtarget_seq = valtarget_seq.to(device)

                kwargs = self.make_kwargs(indices, ids, valtarget_seq, gv_feat, att_feats, att_mask)
                if kwargs['BEAM_SIZE'] > 1:
                    seq, logprob, all_logprobs = model.module.decode_beam(**kwargs)
                else:
                    seq, logprob, all_logprobs, outputs = model.module.decode(**kwargs)

                target_seq = kwargs[cfg.PARAM.TARGET_SENT]

                if target_seq.dtype != torch.long:
                    target_seq = target_seq.long()
    
                loss, loss_info = CrossEntropy()(outputs, target_seq)

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.
        model.train()
        return avg_loss

In [38]:
import torch
import torch.nn as nn
import time


def display(iteration, data_time, batch_time, losses, loss_info):
    info_str = ' (DataTime/BatchTime: {:.3f}/{:.3f}) losses = {:.5f}'.format(data_time.avg, batch_time.avg, losses.avg)
    print('Iteration ' + str(iteration) + info_str + ', lr = ' + str(optim.get_lr()))
    
    for name in sorted(loss_info):
        print('  ' + name + ' = ' + str(loss_info[name]))
    
    data_time.reset()
    batch_time.reset()
    losses.reset()

def make_kwargs(indices, input_seq, target_seq, gv_feat, att_feats, att_mask):
    seq_mask = (input_seq > 0).type(torch.cuda.LongTensor)
    seq_mask[:,0] += 1
    seq_mask_sum = seq_mask.sum(-1)
    max_len = int(seq_mask_sum.max())
    input_seq = input_seq[:, 0:max_len].contiguous()
    target_seq = target_seq[:, 0:max_len].contiguous()

    kwargs = {
        cfg.PARAM.INDICES: indices,
        cfg.PARAM.INPUT_SENT: input_seq,
        cfg.PARAM.TARGET_SENT: target_seq,
        cfg.PARAM.GLOBAL_FEAT: gv_feat,
        cfg.PARAM.ATT_FEATS: att_feats,
        cfg.PARAM.ATT_FEATS_MASK: att_mask
    }
    return kwargs

def forward(model, kwargs):
    logit = model(**kwargs)
    
    target_seq = kwargs[cfg.PARAM.TARGET_SENT]
    
    if target_seq.dtype != torch.long:
        target_seq = target_seq.long()
    
    loss, loss_info = CrossEntropy()(logit, target_seq)
    return loss, loss_info

In [None]:
import matplotlib.pyplot as plt

class Trainer:
    def __init__(self, model, optimizer, dataloader, device, save_path):
        random.seed(cfg.SEED)
        torch.manual_seed(cfg.SEED)
        torch.cuda.manual_seed_all(cfg.SEED)
        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader
        self.device = device
        self.save_path = save_path
        self.train_losses = []
        self.valid_losses = []
        self.data_time = AverageMeter()
        self.batch_time = AverageMeter()
        self.losses = AverageMeter()

    def evaluate(self, model):
        self.model = model
        self.model.eval()
        self.evaler = Evaler(
            eval_ids = cfg.DATA_LOADER.VALT_ID,
            valtarget_seq = cfg.DATA_LOADER.TARGET_SEQ_PATH,
            gv_feat = cfg.DATA_LOADER.TRAIN_GV_FEAT,
            att_feats = cfg.DATA_LOADER.TRAIN_ATT_FEATS)
        avg_loss = self.evaler(self.model, self.device)
    
        return avg_loss
    
    def save_model(self, epoch):
        save_file_path = os.path.join(self.save_path, f'model4_epoch_{epoch + 1}.pth')
        torch.save(self.model.state_dict(), save_file_path)

    def scheduled_sampling(self, epoch):
        if epoch > cfg.TRAIN.SCHEDULED_SAMPLING.START:
            frac = (epoch - cfg.TRAIN.SCHEDULED_SAMPLING.START) // cfg.TRAIN.SCHEDULED_SAMPLING.INC_EVERY
            ss_prob = min(cfg.TRAIN.SCHEDULED_SAMPLING.INC_PROB * frac, cfg.TRAIN.SCHEDULED_SAMPLING.MAX_PROB)
            self.model.module.ss_prob = ss_prob

    def plot_losses(self):
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(self.train_losses) + 1), self.train_losses, 'b-', label='Training Loss')
        plt.plot(range(1, len(self.valid_losses) + 1), self.valid_losses, 'r-', label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Model 4 Training and Validation Loss')
        plt.legend()
        plt.grid(True)
        plt.show()

    def train(self, num_epoch):
        self.model = torch.nn.DataParallel(self.model).to(device)
        self.model.train()
        optimizer = self.optimizer
        optimizer.zero_grad()
        scheduler1 = optimizer.scheduler1()
        scheduler2 = optimizer.scheduler2()
        scheduler3 = optimizer.scheduler3()

        total_batches = len(self.dataloader)

        for epoch in range(num_epoch):
            print(f"Epoch {epoch + 1}/{num_epoch}")
            training_loader = self.dataloader

            training_loss = 0.0
            epoch_start_time = time.time()
            
            for batch_idx, (indices, input_seq, target_seq, gv_feat, att_feats, att_mask) in enumerate(training_loader):
                start_time = time.time() 

                input_seq = input_seq.to(device)
                target_seq = target_seq.to(device)
                gv_feat = gv_feat.to(device)
                att_feats = att_feats.to(device)
                att_mask = att_mask.to(device)

                kwargs = make_kwargs(indices, input_seq, target_seq, gv_feat, att_feats, att_mask)
                loss, loss_info = forward(self.model, kwargs)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                training_loss += loss.item()

                batch_time = time.time() - start_time

                progress = (batch_idx + 1) / total_batches
                progress_bar = ">" * int(30 * progress) + "-" * (30 - int(30 * progress))
                print(f"\r[{progress_bar}] {batch_idx + 1}/{total_batches} - Loss: {loss.item():.5f} - Batch Time: {batch_time:.3f} sec", end='')

                self.losses.reset()

            epoch_time = time.time() - epoch_start_time
            if epoch > num_epoch - 10:
                self.save_model(epoch)
            training_loss /= len(training_loader)
            self.train_losses.append(training_loss)
            valid_loss = self.evaluate(self.model)
            self.valid_losses.append(valid_loss)
            scheduler1.step()
            scheduler2.step()
            scheduler3.step(valid_loss)
            print(f" - Epoch Time: {epoch_time:.3f} sec - Train Loss: {training_loss:.5f} - Validation Loss: {valid_loss:.5f}")

model = XLAN()
optimizer = Optimizer(model)
dataloader = setup_loader(setup_dataset())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
save_path = '/kaggle/working/model/model4'

trainer = Trainer(model, optimizer, dataloader, device, save_path)
num_epoch = 100
trainer.train(num_epoch)
trainer.plot_losses()

In [279]:
class Tester:
    def __init__(self, model, checkpoint, evaler, device, save_path):
        self.model = model
        self.checkpoint = checkpoint
        self.evaler = evaler
        self.device = device
        self.save_path = save_path
        
    def eval(self):
        self.model = torch.nn.DataParallel(self.model).to(self.device)
        self.model.load_state_dict(self.checkpoint)
        result = self.evaler(self.model, self.device)
        with open(self.save_path, "w") as json_file:
            json.dump(result, json_file)
            
#output_dir = '/kaggle/working/save_predict'
#os.makedirs(output_dir, exist_ok=True)

model = XLAN()
checkpoint = torch.load('/kaggle/input/model-x-lan/Model/Model 1/model1_epoch_192.pth')
evaler = EvalTest(
            eval_ids = cfg.DATA_LOADER.TEST_ID,
            gv_feat = cfg.DATA_LOADER.TEST_GV_FEAT,
            att_feats = cfg.DATA_LOADER.TEST_ATT_FEATS)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
save_path = os.path.join(output_dir, 'caption_predictmodel1.json')

tester = Tester(model, checkpoint, evaler, device, save_path)
tester.eval()

113it [00:06, 17.18it/s]


In [280]:
import json

actual = []
predict = []

with open('/kaggle/working/save_predict/caption_predictmodel1.json', 'r') as f:
    results = json.load(f)

for result in results:
    image_filename = result[cfg.INFERENCE.ID_KEY]
    caption_pred = result[cfg.INFERENCE.CAP_KEY]
    caption_act = images_captions_dict[image_filename]
    predict.append(caption_pred)
    caption_act = remove_start_and_end_tokens(caption_act)
    actual.append(caption_act)

print(len(predict), len(actual))

113 113


In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

i = 0

for result in results:
    image_filename = result[cfg.INFERENCE.ID_KEY]
    image_path = os.path.join('/kaggle/input/indonesian-traffic-violation-on-motorcycle/image_crop/image crop/image', image_filename + '.png')
    image = mpimg.imread(image_path)
    plt.imshow(image)
    plt.axis('off')
    plt.show() 
    print(image_filename)
    print(f"predict: {predict[i]}")
    print(f"actual: {actual[i]}")
    print("----------------------")
    i += 1

In [281]:
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.bleu_score import SmoothingFunction
from nltk.util import ngrams
from nltk.metrics import *
from nltk.probability import FreqDist 

tokenized_predict = []
tokenized_actual = []

for sentence in predict:
    tokens = sentence.split()
    tokenized_predict.append(tokens)

for sentence in actual:
    tokens = sentence.split()
    tokenized_actual.append(tokens)

actual_list = [[act] for act in tokenized_actual]

bleu1_score = corpus_bleu(actual_list, tokenized_predict, weights=(1.0, 0, 0, 0))
print("BLEU-1: %f" % bleu1_score)
bleu2_score = corpus_bleu(actual_list, tokenized_predict, weights=(0.5, 0.5, 0, 0), smoothing_function=SmoothingFunction().method1)
print("BLEU-2: %f" % bleu2_score)
bleu3_score = corpus_bleu(actual_list, tokenized_predict, weights=(0.33, 0.33, 0.33, 0), smoothing_function=SmoothingFunction().method1)
print("BLEU-3: %f" % bleu3_score)
bleu4_score = corpus_bleu(actual_list, tokenized_predict, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=SmoothingFunction().method1)
print("BLEU-4: %f" % bleu4_score)

avg_bleu = (bleu1_score+bleu2_score+bleu3_score+bleu4_score)/4
print(f"Average: {avg_bleu}")

def lcsseq(x, y):
    m = len(x)
    n = len(y)
    dp = [[""] * (n + 1) for _ in range(m + 1)]

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if x[i - 1] == y[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + " " + x[i - 1]
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1], key=len)

    return dp[m][n].split()

def rouge_l(evaluated, reference):
    evaluated_tokens = evaluated.split()
    reference_tokens = reference.split()
    lcs = lcsseq(evaluated_tokens, reference_tokens)
    lcs_length = len(lcs)

    reference_length = len(reference_tokens)

    if reference_length == 0:
        return 0
    else:
        return lcs_length / reference_length


rouge_l_scores = []

for p, a in zip(predict, actual):
    rouge_l_scores.append(rouge_l(p, a))

avg_rouge_l = sum(rouge_l_scores) / len(rouge_l_scores)

print(f"ROUGE-L: {avg_rouge_l}")

BLEU-1: 0.655956
BLEU-2: 0.563424
BLEU-3: 0.473858
BLEU-4: 0.395459
Average: 0.5221742257749868
ROUGE-L: 0.6511320508977976


In [81]:
from nltk.util import ngrams
from nltk.metrics import *
from nltk.probability import FreqDist 

def lcsseq(x, y):
    m = len(x)
    n = len(y)
    dp = [[""] * (n + 1) for _ in range(m + 1)]

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if x[i - 1] == y[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + " " + x[i - 1]
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1], key=len)

    return dp[m][n].split()

def rouge_l(evaluated, reference):
    evaluated_tokens = evaluated.split()
    reference_tokens = reference.split()
    lcs = lcsseq(evaluated_tokens, reference_tokens)
    lcs_length = len(lcs)

    reference_length = len(reference_tokens)

    if reference_length == 0:
        return 0
    else:
        return lcs_length / reference_length


rouge_l_scores = []

for p, a in zip(predict, actual):
    rouge_l_scores.append(rouge_l(p, a))

avg_rouge_l = sum(rouge_l_scores) / len(rouge_l_scores)

print(f"Average ROUGE-L score: {avg_rouge_l}")


Average ROUGE-1 score: 0.6554760699997555
Average ROUGE-L score: 0.6271321670983303
