In [1]:
!git clone "https://github.com/zhjohnchan/R2GenCMN.git"

Cloning into 'R2GenCMN'...
remote: Enumerating objects: 88, done.[K
remote: Counting objects: 100% (34/34), done.[K
remote: Compressing objects: 100% (24/24), done.[K
remote: Total 88 (delta 14), reused 23 (delta 8), pack-reused 54[K
Receiving objects: 100% (88/88), 70.19 MiB | 15.55 MiB/s, done.
Resolving deltas: 100% (19/19), done.


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# utils.py

import numpy as np
import cv2
import torch


def penalty_builder(penalty_config):
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)


def length_wu(length, logprobs, alpha=0.):
    """
    NMT length re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.
    """

    modifier = (((5 + length) ** alpha) /
                ((5 + 1) ** alpha))
    return logprobs / modifier


def length_average(length, logprobs, alpha=0.):
    """
    Returns the average probability of tokens in a sequence.
    """
    return logprobs / length


def split_tensors(n, x):
    if torch.is_tensor(x):
        assert x.shape[0] % n == 0
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif type(x) is list or type(x) is tuple:
        x = [split_tensors(n, _) for _ in x]
    elif x is None:
        x = [None] * n
    return x


def repeat_tensors(n, x):
    """
    For a tensor of size Bx..., we repeat it n times, and make it Bnx...
    For collections, do nested repeat
    """
    if torch.is_tensor(x):
        x = x.unsqueeze(1)  # Bx1x...
        x = x.expand(-1, n, *([-1] * len(x.shape[2:])))  # Bxnx...
        x = x.reshape(x.shape[0] * n, *x.shape[2:])  # Bnx...
    elif type(x) is list or type(x) is tuple:
        x = [repeat_tensors(n, _) for _ in x]
    return x


def generate_heatmap(image, weights):
    image = image.transpose(1, 2, 0)
    height, width, _ = image.shape
    weights = weights.reshape(int(weights.shape[0] ** 0.5), int(weights.shape[0] ** 0.5))
    weights = weights - np.min(weights)
    weights = weights / np.max(weights)
    weights = cv2.resize(weights, (width, height))
    weights = np.uint8(255 * weights)
    heatmap = cv2.applyColorMap(weights, cv2.COLORMAP_JET)
    result = heatmap * 0.5 + image * 0.5
    return result

In [4]:
# caption_model

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

# import modules.utils as utils


class CaptionModel(nn.Module):
    def __init__(self):
        super(CaptionModel, self).__init__()

    # implements beam search
    # calls beam_step and returns the final set of beams
    # augments log-probabilities with diversity terms when number of groups > 1

    def forward(self, *args, **kwargs):
        mode = kwargs.get('mode', 'forward')
        if 'mode' in kwargs:
            del kwargs['mode']
        return getattr(self, '_' + mode)(*args, **kwargs)

    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobs = logprobs.clone()
            batch_size = beam_seq_table[0].shape[0]

            if divm > 0:
                change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
                for prev_choice in range(divm):
                    prev_decisions = beam_seq_table[prev_choice][:, :, local_time]  # Nxb
                    for prev_labels in range(bdash):
                        change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1),
                                            change.new_ones(batch_size, 1))

                if local_time == 0:
                    logprobs = logprobs - change * diversity_lambda
                else:
                    logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda

            return logprobs, unaug_logprobs

        # does one step of classical beam search

        def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobs: probabilities augmented after diversity N*bxV
            # beam_size: obvious
            # t        : time instant
            # beam_seq : tensor contanining the beams
            # beam_seq_logprobs: tensor contanining the beam logprobs
            # beam_logprobs_sum: tensor contanining joint logprobs
            # OUPUTS:
            # beam_seq : tensor containing the word indices of the decoded captions Nxbxl
            # beam_seq_logprobs : log-probability of each decision made, NxbxlxV
            # beam_logprobs_sum : joint log-probability of each beam Nxb

            batch_size = beam_logprobs_sum.shape[0]
            vocab_size = logprobs.shape[-1]
            logprobs = logprobs.reshape(batch_size, -1, vocab_size)  # NxbxV
            if t == 0:
                assert logprobs.shape[1] == 1
                beam_logprobs_sum = beam_logprobs_sum[:, :1]
            candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs  # beam_logprobs_sum Nxb logprobs is NxbxV
            ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
            ys, ix = ys[:, :beam_size], ix[:, :beam_size]
            beam_ix = ix // vocab_size  # Nxb which beam
            selected_ix = ix % vocab_size  # Nxb # which world
            state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(
                -1)  # N*b which in Nxb beams

            if t > 0:
                # gather according to beam_ix
                assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) ==
                        beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
                beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))

                beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(
                    beam_seq_logprobs))

            beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1)  # beam_seq Nxbxl
            beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
                                logprobs.reshape(batch_size, -1).gather(1, ix)
            assert (beam_logprobs_sum == ys).all()
            _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
            beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1,
                                                                                      beam_ix.unsqueeze(-1).expand(-1,
                                                                                                                   -1,
                                                                                                                   vocab_size))  # NxbxV
            assert (_tmp_beam_logprobs == beam_logprobs).all()
            beam_seq_logprobs = torch.cat([
                beam_seq_logprobs,
                beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)

            new_state = [None for _ in state]
            for _ix in range(len(new_state)):
                #  copy over state in previous beam q to new beam at vix
                new_state[_ix] = state[_ix][:, state_ix]
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        batch_size = init_logprobs.shape[0]
        device = init_logprobs.device
        # INITIALIZATIONS
        beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in
                                   range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
        state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
        logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
        # END INIT

        # Chunk elements in the args
        args = list(args)
        args = split_tensors(group_size, args)  # For each arg, turn (Bbg)x... to (Bb)x(g)x...
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in
                    range(group_size)]  # group_name, arg_name, model_name
        else:
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        for t in range(self.max_seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.max_seq_length + divm - 1:
                    # add diversity
                    logprobs = logprobs_table[divm]
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobs.scatter_(1, beam_seq_table[divm][:, :, t - divm - 1].reshape(-1, 1).to(device),
                                          float('-inf'))
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1) - 1)] == 'UNK':
                        logprobs[:, logprobs.size(1) - 1] = logprobs[:, logprobs.size(1) - 1] - 1000
                        # diversity is added here
                    # the function directly modifies the logprobs values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    logprobs, unaug_logprobs = add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash)

                    # infer new beams
                    beam_seq_table[divm], \
                    beam_seq_logprobs_table[divm], \
                    beam_logprobs_sum_table[divm], \
                    state_table[divm] = beam_step(logprobs,
                                                  unaug_logprobs,
                                                  bdash,
                                                  t - divm,
                                                  beam_seq_table[divm],
                                                  beam_seq_logprobs_table[divm],
                                                  beam_logprobs_sum_table[divm],
                                                  state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for b in range(batch_size):
                        is_end = beam_seq_table[divm][b, :, t - divm] == self.eos_idx
                        assert beam_seq_table[divm].shape[-1] == t - divm + 1
                        if t == self.max_seq_length + divm - 1:
                            is_end.fill_(1)
                        for vix in range(bdash):
                            if is_end[vix]:
                                final_beam = {
                                    'seq': beam_seq_table[divm][b, vix].clone(),
                                    'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
                                    'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
                                    'p': beam_logprobs_sum_table[divm][b, vix].item()
                                }
                                final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
                                done_beams_table[b][divm].append(final_beam)
                        beam_logprobs_sum_table[divm][b, is_end] -= 1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][:, :, t - divm].reshape(-1)
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(
                            args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
                            for b in range(batch_size)]
        done_beams = [sum(_, []) for _ in done_beams_table]
        return done_beams

    def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[
                            prev_labels]] - diversity_lambda
            return unaug_logprobsf

        # does one step of classical beam search

        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobsf: probabilities augmented after diversity
            # beam_size: obvious
            # t        : time instant
            # beam_seq : tensor contanining the beams
            # beam_seq_logprobs: tensor contanining the beam logprobs
            # beam_logprobs_sum: tensor contanining joint logprobs
            # OUPUTS:
            # beam_seq : tensor containing the word indices of the decoded captions
            # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            # beam_logprobs_sum : joint log-probability of each beam

            ys, ix = torch.sort(logprobsf, 1, True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols):  # for each column (word, essentially)
                for q in range(rows):  # for each beam expansion
                    # compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q, c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
                    candidates.append({'c': ix[q, c], 'q': q, 'p': candidate_logprob, 'r': unaug_logprobsf[q]})
            candidates = sorted(candidates, key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            # beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                # we''ll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                # fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
                # rearrange recurrent states
                for state_ix in range(len(new_state)):
                    #  copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v['q']]  # dimension one is time step
                # append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c']  # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
                beam_logprobs_sum[vix] = v['p']  # the new (sum) logprob along this beam
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        # INITIALIZATIONS
        beam_seq_table = [torch.LongTensor(self.max_seq_length, bdash).zero_() for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(self.max_seq_length, bdash, self.vocab_size + 1).zero_() for _ in
                                   range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[] for _ in range(group_size)]
        # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
        state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # END INIT

        # Chunk elements in the args
        args = list(args)
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[_.chunk(group_size) if _ is not None else [None] * group_size for _ in args_] for args_ in
                    args]  # arg_name, model_name, group_name
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in
                    range(group_size)]  # group_name, arg_name, model_name
        else:
            args = [_.chunk(group_size) if _ is not None else [None] * group_size for _ in args]
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        for t in range(self.max_seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.max_seq_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm].float()
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobsf.scatter_(1, beam_seq_table[divm][t - divm - 1].unsqueeze(1).cuda(), float('-inf'))
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1) - 1)] == 'UNK':
                        logprobsf[:, logprobsf.size(1) - 1] = logprobsf[:, logprobsf.size(1) - 1] - 1000
                        # diversity is added here
                    # the function directly modifies the logprobsf values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    unaug_logprobsf = add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash)

                    # infer new beams
                    beam_seq_table[divm], \
                    beam_seq_logprobs_table[divm], \
                    beam_logprobs_sum_table[divm], \
                    state_table[divm], \
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t - divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for vix in range(bdash):
                        if beam_seq_table[divm][t - divm, vix] == self.eos_idx or t == self.max_seq_length + divm - 1:
                            final_beam = {
                                'seq': beam_seq_table[divm][:, vix].clone(),
                                'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                                'p': beam_logprobs_sum_table[divm][vix].item()
                            }
                            final_beam['p'] = length_penalty(t - divm + 1, final_beam['p'])
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][t - divm]
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(
                            args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
        done_beams = sum(done_beams_table, [])
        return done_beams

    def sample_next_word(self, logprobs, sample_method, temperature):
        if sample_method == 'greedy':
            sampleLogprobs, it = torch.max(logprobs.data, 1)
            it = it.view(-1).long()
        elif sample_method == 'gumbel':  # gumbel softmax
            def sample_gumbel(shape, eps=1e-20):
                U = torch.rand(shape).cuda()
                return -torch.log(-torch.log(U + eps) + eps)

            def gumbel_softmax_sample(logits, temperature):
                y = logits + sample_gumbel(logits.size())
                return F.log_softmax(y / temperature, dim=-1)

            _logprobs = gumbel_softmax_sample(logprobs, temperature)
            _, it = torch.max(_logprobs.data, 1)
            sampleLogprobs = logprobs.gather(1, it.unsqueeze(1))  # gather the logprobs at sampled positions
        else:
            logprobs = logprobs / temperature
            if sample_method.startswith('top'):  # topk sampling
                top_num = float(sample_method[3:])
                if 0 < top_num < 1:
                    # nucleus sampling from # The Curious Case of Neural Text Degeneration
                    probs = F.softmax(logprobs, dim=1)
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
                    _cumsum = sorted_probs.cumsum(1)
                    mask = _cumsum < top_num
                    mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)
                    sorted_probs = sorted_probs * mask.float()
                    sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
                    logprobs.scatter_(1, sorted_indices, sorted_probs.log())
                else:
                    the_k = int(top_num)
                    tmp = torch.empty_like(logprobs).fill_(float('-inf'))
                    topk, indices = torch.topk(logprobs, the_k, dim=1)
                    tmp = tmp.scatter(1, indices, topk)
                    logprobs = tmp
            it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
            sampleLogprobs = logprobs.gather(1, it.unsqueeze(1))  # gather the logprobs at sampled positions
        return it, sampleLogprobs


In [5]:
# att_model

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

# import modules.utils as utils
# from modules.caption_model import CaptionModel


def sort_pack_padded_sequence(input, lengths):
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
    inv_ix = indices.clone()
    inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
    return tmp, inv_ix


def pad_unsort_packed_sequence(input, inv_ix):
    tmp, _ = pad_packed_sequence(input, batch_first=True)
    tmp = tmp[inv_ix]
    return tmp


def pack_wrapper(module, att_feats, att_masks):
    if att_masks is not None:
        packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
        return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
    else:
        return module(att_feats)


class AttModel(CaptionModel):
    def __init__(self, args, tokenizer):
        super(AttModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.idx2token)
        self.input_encoding_size = args['d_model']
        self.rnn_size = args['d_ff']
        self.num_layers = args['num_layers']
        self.drop_prob_lm = args['drop_prob_lm']
        self.max_seq_length = args['max_seq_length']
        self.att_feat_size = args['d_vf']
        self.att_hid_size = args['d_model']

        self.bos_idx = args['bos_idx']
        self.eos_idx = args['eos_idx']
        self.pad_idx = args['pad_idx']

        self.use_bn = args['use_bn']

        self.embed = lambda x: x
        self.fc_embed = lambda x: x
        self.att_embed = nn.Sequential(*(
                ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) +
                (nn.Linear(self.att_feat_size, self.input_encoding_size),
                 nn.ReLU(),
                 nn.Dropout(self.drop_prob_lm)) +
                ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ())))

    def clip_att(self, att_feats, att_masks):
        # Clip the length of att_masks and att_feats to the maximum length
        if att_masks is not None:
            max_len = att_masks.data.long().sum(1).max()
            att_feats = att_feats[:, :max_len].contiguous()
            att_masks = att_masks[:, :max_len].contiguous()
        return att_feats, att_masks

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        # embed fc and att feats
        fc_feats = self.fc_embed(fc_feats)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        # Project the attention feats first to reduce memory and computation comsumptions.
        p_att_feats = self.ctx2att(att_feats)

        return fc_feats, att_feats, p_att_feats, att_masks

    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
        # 'it' contains a word index
        xt = self.embed(it)

        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
        if output_logsoftmax:
            logprobs = F.log_softmax(self.logit(output), dim=1)
        else:
            logprobs = self.logit(output)

        return logprobs, state

    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        sample_n = opt.get('sample_n', 10)
        # when sample_n == beam_size then each beam is a sample.
        assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
        batch_size = fc_feats.size(0)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]

        state = self.init_hidden(batch_size)

        # first step, feed bos
        it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
        logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks =repeat_tensors(beam_size,
                                                                                  [p_fc_feats, p_att_feats,
                                                                                   pp_att_feats, p_att_masks]
                                                                                  )
        self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
        for k in range(batch_size):
            if sample_n == beam_size:
                for _n in range(sample_n):
                    seq_len = self.done_beams[k][_n]['seq'].shape[0]
                    seq[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['seq']
                    seqLogprobs[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['logps']
            else:
                seq_len = self.done_beams[k][0]['seq'].shape[0]
                seq[k, :seq_len] = self.done_beams[k][0]['seq']  # the first beam has highest cumulative score
                seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq, seqLogprobs

    def _sample(self, fc_feats, att_feats, att_masks=None, update_opts={}):
        opt = self.args
        opt.update(**update_opts)

        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        sample_n = int(opt.get('sample_n', 1))
        group_size = opt.get('group_size', 1)
        output_logsoftmax = opt.get('output_logsoftmax', 1)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)
        if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
            return self._sample_beam(fc_feats, att_feats, att_masks, opt)
        if group_size > 1:
            return self._diverse_sample(fc_feats, att_feats, att_masks, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size * sample_n)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        if sample_n > 1:
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(sample_n,
                                                                                      [p_fc_feats, p_att_feats,
                                                                                       pp_att_feats, p_att_masks]
                                                                                      )

        trigrams = []  # will be a list of batch_size dictionaries

        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        for t in range(self.max_seq_length + 1):
            if t == 0:  # input <bos>
                it = fc_feats.new_full([batch_size * sample_n], self.bos_idx, dtype=torch.long)

            logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state,
                                                      output_logsoftmax=output_logsoftmax)

            if decoding_constraint and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            # Mess with trigrams
            # Copy from https://github.com/lukemelas/image-paragraph-captioning
            if block_trigrams and t >= 3:
                # Store trigram generated at last step
                prev_two_batch = seq[:, t - 3:t - 1]
                for i in range(batch_size):  # = seq.size(0)
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    current = seq[i][t - 1]
                    if t == 3:  # initialize
                        trigrams.append({prev_two: [current]})  # {LongTensor: list containing 1 int}
                    elif t > 3:
                        if prev_two in trigrams[i]:  # add to list
                            trigrams[i][prev_two].append(current)
                        else:  # create list
                            trigrams[i][prev_two] = [current]
                # Block used trigrams at next step
                prev_two_batch = seq[:, t - 2:t]
                mask = torch.zeros(logprobs.size(), requires_grad=False).cuda()  # batch_size x vocab_size
                for i in range(batch_size):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    if prev_two in trigrams[i]:
                        for j in trigrams[i][prev_two]:
                            mask[i, j] += 1
                # Apply mask to log probs
                # logprobs = logprobs - (mask * 1e9)
                alpha = 2.0  # = 4
                logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

            # sample the next word
            if t == self.max_seq_length:  # skip if we achieve maximum length
                break
            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

            # stop when all finished
            if t == 0:
                unfinished = it != self.eos_idx
            else:
                it[~unfinished] = self.pad_idx  # This allows eos_idx not being overwritten to 0
                logprobs = logprobs * unfinished.unsqueeze(1).float()
                unfinished = unfinished * (it != self.eos_idx)
            seq[:, t] = it
            seqLogprobs[:, t] = logprobs
            # quit loop if all sequences have finished
            if unfinished.sum() == 0:
                break

        return seq, seqLogprobs

    def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):

        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        trigrams_table = [[] for _ in range(group_size)]  # will be a list of batch_size dictionaries

        seq_table = [fc_feats.new_full((batch_size, self.max_seq_length), self.pad_idx, dtype=torch.long) for _ in
                     range(group_size)]
        seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.max_seq_length) for _ in range(group_size)]
        state_table = [self.init_hidden(batch_size) for _ in range(group_size)]

        for tt in range(self.max_seq_length + group_size):
            for divm in range(group_size):
                t = tt - divm
                seq = seq_table[divm]
                seqLogprobs = seqLogprobs_table[divm]
                trigrams = trigrams_table[divm]
                if t >= 0 and t <= self.max_seq_length - 1:
                    if t == 0:  # input <bos>
                        it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
                    else:
                        it = seq[:, t - 1]  # changed

                    logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats,
                                                                          p_att_masks, state_table[divm])  # changed
                    logprobs = F.log_softmax(logprobs / temperature, dim=-1)

                    # Add diversity
                    if divm > 0:
                        unaug_logprobs = logprobs.clone()
                        for prev_choice in range(divm):
                            prev_decisions = seq_table[prev_choice][:, t]
                            logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda

                    if decoding_constraint and t > 0:
                        tmp = logprobs.new_zeros(logprobs.size())
                        tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                        logprobs = logprobs + tmp

                    # Mess with trigrams
                    if block_trigrams and t >= 3:
                        # Store trigram generated at last step
                        prev_two_batch = seq[:, t - 3:t - 1]
                        for i in range(batch_size):  # = seq.size(0)
                            prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                            current = seq[i][t - 1]
                            if t == 3:  # initialize
                                trigrams.append({prev_two: [current]})  # {LongTensor: list containing 1 int}
                            elif t > 3:
                                if prev_two in trigrams[i]:  # add to list
                                    trigrams[i][prev_two].append(current)
                                else:  # create list
                                    trigrams[i][prev_two] = [current]
                        # Block used trigrams at next step
                        prev_two_batch = seq[:, t - 2:t]
                        mask = torch.zeros(logprobs.size(), requires_grad=False).cuda()  # batch_size x vocab_size
                        for i in range(batch_size):
                            prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                            if prev_two in trigrams[i]:
                                for j in trigrams[i][prev_two]:
                                    mask[i, j] += 1
                        # Apply mask to log probs
                        # logprobs = logprobs - (mask * 1e9)
                        alpha = 2.0  # = 4
                        logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

                    it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)

                    # stop when all finished
                    if t == 0:
                        unfinished = it != self.eos_idx
                    else:
                        unfinished = seq[:, t - 1] != self.pad_idx & seq[:, t - 1] != self.eos_idx
                        it[~unfinished] = self.pad_idx
                        unfinished = unfinished & (it != self.eos_idx)  # changed
                    seq[:, t] = it
                    seqLogprobs[:, t] = sampleLogprobs.view(-1)

        return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table,
                                                                                           1).reshape(
            batch_size * group_size, -1)


In [6]:
# base_cmn.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# from .att_model import pack_wrapper, AttModel


def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def subsequent_mask(size):
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0


def attention(query, key, value, mask=None, dropout=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


def memory_querying_responding(query, key, value, mask=None, dropout=None, topk=32):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    selected_scores, idx = scores.topk(topk)
    dummy_value = value.unsqueeze(2).expand(idx.size(0), idx.size(1), idx.size(2), value.size(-2), value.size(-1))
    dummy_idx = idx.unsqueeze(-1).expand(idx.size(0), idx.size(1), idx.size(2), idx.size(3), value.size(-1))
    selected_value = torch.gather(dummy_value, 3, dummy_idx)
    p_attn = F.softmax(selected_scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn.unsqueeze(3), selected_value).squeeze(3), p_attn


class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, cmn):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.cmn = cmn

    def forward(self, src, tgt, src_mask, tgt_mask, memory_matrix):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask, memory_matrix=memory_matrix)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask, past=None, memory_matrix=None):
        embeddings = self.tgt_embed(tgt)

        # Memory querying and responding for textual features
        dummy_memory_matrix = memory_matrix.unsqueeze(0).expand(embeddings.size(0), memory_matrix.size(0), memory_matrix.size(1))
        responses = self.cmn(embeddings, dummy_memory_matrix, dummy_memory_matrix)
        embeddings = embeddings + responses
        # Memory querying and responding for textual features

        return self.decoder(embeddings, memory, src_mask, tgt_mask, past=past)


class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        _x = sublayer(self.norm(x))
        if type(_x) is tuple:
            return x + self.dropout(_x[0]), _x[1]
        return x + self.dropout(_x)


class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)


class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask, past=None):
        if past is not None:
            present = [[], []]
            x = x[:, -1:]
            tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None
            past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0)))
        else:
            past = [None] * len(self.layers)
        for i, (layer, layer_past) in enumerate(zip(self.layers, past)):
            x = layer(x, memory, src_mask, tgt_mask,
                      layer_past)
            if layer_past is not None:
                present[0].append(x[1][0])
                present[1].append(x[1][1])
                x = x[0]
        if past[0] is None:
            return self.norm(x)
        else:
            return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)]


class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask, layer_past=None):
        m = memory
        if layer_past is None:
            x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
            x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
            return self.sublayer[2](x, self.feed_forward)
        else:
            present = [None, None]
            x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0]))
            x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1]))
            return self.sublayer[2](x, self.feed_forward), present


class MultiThreadMemory(nn.Module):
    def __init__(self, h, d_model, dropout=0.1, topk=32):
        super(MultiThreadMemory, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.topk = topk

    def forward(self, query, key, value, mask=None, layer_past=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1:
            query = self.linears[0](query)
            key, value = layer_past[0], layer_past[1]
            present = torch.stack([key, value])
        else:
            query, key, value = \
                [l(x) for l, x in zip(self.linears, (query, key, value))]
        if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1):
            past_key, past_value = layer_past[0], layer_past[1]
            key = torch.cat((past_key, key), dim=1)
            value = torch.cat((past_value, value), dim=1)
            present = torch.stack([key, value])

        query, key, value = \
            [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for x in [query, key, value]]

        x, self.attn = memory_querying_responding(query, key, value, mask=mask, dropout=self.dropout, topk=self.topk)

        x = x.transpose(1, 2).contiguous() \
            .view(nbatches, -1, self.h * self.d_k)
        if layer_past is not None:
            return self.linears[-1](x), present
        else:
            return self.linears[-1](x)


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None, layer_past=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1:
            query = self.linears[0](query)
            key, value = layer_past[0], layer_past[1]
            present = torch.stack([key, value])
        else:
            query, key, value = \
                [l(x) for l, x in zip(self.linears, (query, key, value))]

        if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1):
            past_key, past_value = layer_past[0], layer_past[1]
            key = torch.cat((past_key, key), dim=1)
            value = torch.cat((past_value, value), dim=1)
            present = torch.stack([key, value])

        query, key, value = \
            [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for x in [query, key, value]]

        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        x = x.transpose(1, 2).contiguous() \
            .view(nbatches, -1, self.h * self.d_k)
        if layer_past is not None:
            return self.linears[-1](x), present
        else:
            return self.linears[-1](x)


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))


class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class BaseCMN(AttModel):

    def make_model(self, tgt_vocab, cmn):
        c = copy.deepcopy
        attn = MultiHeadedAttention(self.num_heads, self.d_model)
        ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout)
        position = PositionalEncoding(self.d_model, self.dropout)
        model = Transformer(
            Encoder(EncoderLayer(self.d_model, c(attn), c(ff), self.dropout), self.num_layers),
            Decoder(DecoderLayer(self.d_model, c(attn), c(attn), c(ff), self.dropout), self.num_layers),
            nn.Sequential(c(position)),
            nn.Sequential(Embeddings(self.d_model, tgt_vocab), c(position)), cmn)
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        return model

    def __init__(self, args, tokenizer):
        super(BaseCMN, self).__init__(args, tokenizer)
        self.args = args
        self.num_layers = args['num_layers']
        self.d_model = args['d_model']      #the dimension of Transformer
        self.d_ff = args['d_ff']            #the dimension of FFN
        self.num_heads = args['num_heads']  #the number of heads in Transformer.
        self.dropout = args['dropout']      #the dropout rate of Transformer
        self.topk = args['topk']            #the number of k

        tgt_vocab = self.vocab_size + 1

        self.cmn = MultiThreadMemory(args['num_heads'], args['d_model'], topk=args['topk'])

        self.model = self.make_model(tgt_vocab, self.cmn)
        self.logit = nn.Linear(args['d_model'], tgt_vocab)

        self.memory_matrix = nn.Parameter(torch.FloatTensor(args['cmm_size'], args['cmm_dim']))
        nn.init.normal_(self.memory_matrix, 0, 1 / args['cmm_dim'])

    def init_hidden(self, bsz):
        return []

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
        memory = self.model.encode(att_feats, att_masks)

        return fc_feats[..., :1], att_feats[..., :1], memory, att_masks

    def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

        if att_masks is None:
            att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)

        # Memory querying and responding for visual features
        dummy_memory_matrix = self.memory_matrix.unsqueeze(0).expand(att_feats.size(0), self.memory_matrix.size(0), self.memory_matrix.size(1))
        responses = self.cmn(att_feats, dummy_memory_matrix, dummy_memory_matrix)
        att_feats = att_feats + responses
        # Memory querying and responding for visual features

        att_masks = att_masks.unsqueeze(-2)
        if seq is not None:
            seq = seq[:, :-1]
            seq_mask = (seq.data > 0)
            seq_mask[:, 0] += True

            seq_mask = seq_mask.unsqueeze(-2)
            seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
        else:
            seq_mask = None

        return att_feats, seq, att_masks, seq_mask

    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
       att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)

       # Print the shapes of tensors before and after each operation
      #  print("Input shapes:")
      #  print("att_feats:", att_feats.shape)
      #  print("seq:", seq.shape)
      #  print("att_masks:", att_masks.shape)
      #  print("seq_mask:", seq_mask.shape)

       out = self.model(att_feats, seq, att_masks, seq_mask, memory_matrix=self.memory_matrix)
       outputs = F.log_softmax(self.logit(out), dim=-1)

      #  # Print the shape of the final output tensor
      #  print("Output shape:")
      #  print("outputs:", outputs.shape)

       return outputs


    def _save_attns(self, start=False):
        if start:
            self.attention_weights = []
        self.attention_weights.append([layer.src_attn.attn.cpu().numpy() for layer in self.model.decoder.layers])

    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
        if len(state) == 0:
            ys = it.unsqueeze(1)
            past = [fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model),
                    fc_feats_ph.new_zeros(self.num_layers * 2, fc_feats_ph.shape[0], 0, self.d_model)]
        else:
            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
            past = state[1:]
        out, past = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device), past=past,
                                      memory_matrix=self.memory_matrix)

        if not self.training:
            self._save_attns(start=len(state) == 0)
        return out[:, -1], [ys.unsqueeze(0)] + past


In [35]:
# visual_extractor.py

import torch
import torch.nn as nn
import torchvision.models as models


class VisualExtractor(nn.Module):
    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.visual_extractor = args['visual_extractor']
        self.pretrained = args['visual_extractor_pretrained']
        model = getattr(models, self.visual_extractor)(pretrained=self.pretrained)
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

    def forward(self, images):
        patch_feats = self.model(images)
        avg_feats = self.avg_fnt(patch_feats).squeeze().reshape(-1, patch_feats.size(1))
        batch_size, feat_size, _, _ = patch_feats.shape
        patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
        return patch_feats, avg_feats


In [8]:
# models.py
import numpy as np
import torch
import torch.nn as nn




class BaseCMNModel(nn.Module):
    def __init__(self, args, tokenizer):
        super(BaseCMNModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer

        # Create the visual extractor
        self.visual_extractor = VisualExtractor(args)
        print("bellow the visual extractor")
        visual_out_features = 2048  # Adjust this based on the actual output size

        # Create the linear layer with the correct input and output size
        linear_input_size = visual_out_features
        linear_output_size = 980  # Adjust this value based on the actual output size
        print(f"Visual output size: {visual_out_features}")
        print(f"linear_input_size: {linear_input_size}")
        self.linear_layer = nn.Linear(linear_input_size, linear_output_size)
        print("bellow the linear layer")

        # Create the encoder-decoder module
        self.encoder_decoder = BaseCMN(args, tokenizer)
        print("in basecmnmodel bellow the encode_decoder")

        if args['dataset_name'] == 'iu_xray':
            self.forward = self.forward_iu_xray
            print("in basecmnmodel iu xray forward pass")
        else:
            self.forward = self.forward_mimic_cxr

    def forward_iu_xray(self, images, targets=None, mode='train', update_opts={}):
        # Forward pass through the visual extractor
        images = images.view(-1, *images.shape[2:])
        print("In forward_iu_xray in BaseCMNModel")

        _, avg_feats = self.visual_extractor(images)
        print("Below the avg_feats in BaseCMNModel")
        print(f"Shape of avg_feats before flattening: {avg_feats.shape}")

        # Flatten the features
        batch_size, num_features = avg_feats.size(0), np.prod(avg_feats.size()[1:])
        # Flatten the features
        avg_feats = avg_feats.view(batch_size, -1)
        print(f"Shape of avg_feats after flattening: {avg_feats.shape}")

        linear_output_size = 980
        avg_feats = avg_feats.view(batch_size, num_features)
        print(f"Shape of avg_feats after flattening: {avg_feats.shape}")

        # Forward pass through the linear layer
        linear_output = self.linear_layer(avg_feats)
        print("Below the linear_output in forward_iu_xray")

        # Continue with the rest of your model...
        # Example: Assuming self.encoder_decoder has a forward method
        if mode == 'train':
            print('in train')
            output = self.encoder_decoder(linear_output, targets=targets, mode='forward')  # Adjust arguments
            return output
        elif mode == 'sample':
            output, output_probs = self.encoder_decoder(linear_output, mode='sample', update_opts=update_opts)  # Adjust arguments
            return output, output_probs
        else:
            raise ValueError

    def forward_mimic_cxr(self, images, targets=None, mode='train', update_opts={}):
        att_feats, fc_feats = self.visual_extractor(images)
        if mode == 'train':
            output = self.encoder_decoder(fc_feats, att_feats, targets, mode='forward')
            return output
        elif mode == 'sample':
            output, output_probs = self.encoder_decoder(fc_feats, att_feats, mode='sample', update_opts=update_opts)
            return output, output_probs
        else:
            raise ValueError


In [9]:
# dataset.py

import json
import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class BaseDataset(Dataset):
    def __init__(self, args, tokenizer, split, transform=None):
        self.image_dir = args['image_dir']
        self.ann_path = args['ann_path']
        self.max_seq_length = args['max_seq_length']
        self.split = split
        self.tokenizer = tokenizer
        self.transform = transform
        self.ann = json.loads(open(self.ann_path, 'r').read())
        self.examples = self.ann[self.split]
        for i in range(len(self.examples)):
            self.examples[i]['ids'] = tokenizer(self.examples[i]['report'])[:self.max_seq_length]
            self.examples[i]['mask'] = [1] * len(self.examples[i]['ids'])

    def __len__(self):
        return len(self.examples)


class IuxrayMultiImageDataset(BaseDataset):
    def __getitem__(self, idx):
        example = self.examples[idx]
        image_id = example['id']
        image_path = example['image_path']
        image_1 = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
        image_2 = Image.open(os.path.join(self.image_dir, image_path[1])).convert('RGB')
        if self.transform is not None:
            image_1 = self.transform(image_1)
            image_2 = self.transform(image_2)
        image = torch.stack((image_1, image_2), 0)
        report_ids = example['ids']
        report_masks = example['mask']
        seq_length = len(report_ids)
        sample = (image_id, image, report_ids, report_masks, seq_length)
        return sample


class MimiccxrSingleImageDataset(BaseDataset):
    def __getitem__(self, idx):
        example = self.examples[idx]
        image_id = example['id']
        image_path = example['image_path']
        image = Image.open(os.path.join(self.image_dir, image_path[0])).convert('RGB')
        image_id = os.path.join(self.image_dir, image_path[0])
        if self.transform is not None:
            image = self.transform(image)
        report_ids = example['ids']
        report_masks = example['mask']
        seq_length = len(report_ids)
        sample = (image_id, image, report_ids, report_masks, seq_length)
        return sample


In [10]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

class R2DataLoader(DataLoader):
    def __init__(self, args, tokenizer, split, shuffle):
        self.args = args
        self.dataset_name = args['dataset_name']
        self.batch_size = args['batch_size']
        self.shuffle = shuffle
        self.num_workers = args['num_workers']
        self.tokenizer = tokenizer
        self.split = split

        if split == 'train':
            # Training data transformations
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))])
        else:
            # Validation or test data transformations
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))])

        if self.dataset_name == 'iu_xray':
            self.dataset = IuxrayMultiImageDataset(self.args, self.tokenizer, self.split, transform=self.transform)
        else:
            self.dataset = MimiccxrSingleImageDataset(self.args, self.tokenizer, self.split, transform=self.transform)

        self.init_kwargs = {
            'dataset': self.dataset,
            'batch_size': self.batch_size,
            'shuffle': self.shuffle,
            'collate_fn': self.collate_fn,
            'num_workers': self.num_workers
        }
        super().__init__(**self.init_kwargs)

    @staticmethod
    def collate_fn(data):
        image_id_batch, image_batch, report_ids_batch, report_masks_batch, seq_lengths_batch = zip(*data)

        image_batch = torch.stack([img.squeeze(0) for img in image_batch], 0)

        # print("Input image shape:", image_batch.shape)  # Add this line
        max_seq_length = max(seq_lengths_batch)

        target_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int)
        target_masks_batch = np.zeros((len(report_ids_batch), max_seq_length), dtype=int)

        for i, report_ids in enumerate(report_ids_batch):
            target_batch[i, :len(report_ids)] = report_ids

        for i, report_masks in enumerate(report_masks_batch):
            target_masks_batch[i, :len(report_masks)] = report_masks

        print("ending the dataloader")
        return image_id_batch, image_batch, torch.LongTensor(target_batch), torch.FloatTensor(target_masks_batch)


In [11]:

#loss.py

import torch
import torch.nn as nn


class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):
        # truncate to the same size
        target = target[:, :input.size(1)]
        mask = mask[:, :input.size(1)]
        output = -input.gather(2, target.long().unsqueeze(2)).squeeze(2) * mask
        output = torch.sum(output) / torch.sum(mask)
        return output


def compute_loss(output, reports_ids, reports_masks):
    criterion = LanguageModelCriterion()
    loss = criterion(output, reports_ids[:, 1:], reports_masks[:, 1:]).mean()
    return loss


In [12]:
# blueScore.py

# bleu_scorer.py
# David Chiang <chiang@isi.edu>

# Copyright (c) 2004-2006 University of Maryland. All rights
# reserved. Do not redistribute without permission from the
# author. Not for commercial use.

# Modified by:
# Hao Fang <hfang@uw.edu>
# Tsung-Yi Lin <tl483@cornell.edu>

# Last modified : Wed 22 May 2019 08:10:00 PM EDT
# By Sabarish Sivanath
# To support Python 3

'''Provides:
cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
'''

import copy
import sys, math, re
from collections import defaultdict

def precook(s, n=4, out=False):
    """Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well."""
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return (len(words), counts)

def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.'''

    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        for (ngram,count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    # Calculate effective reference sentence length.
    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen))/len(reflen)

    ## lhuang: N.B.: leave reflen computaiton to the very end!!

    ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)

    return (reflen, maxcounts)

def cook_test(test, refs , eff=None, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.'''

    reflen = refs[0]
    refmaxcounts = refs[1]

    testlen, counts = precook(test, n, True)

    result = {}

    # Calculate effective reference sentence length.

    if eff == "closest":
        result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
    else: ## i.e., "average" or "shortest" or None
        result["reflen"] = reflen

    result["testlen"] = testlen

    result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]

    result['correct'] = [0]*n
    for (ngram, count) in counts.items():
        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)

    return result

class BleuScorer(object):
    """Bleu scorer.
    """

    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
    # special_reflen is used in oracle (proportional effective ref len for a node).

    def copy(self):
        ''' copy the refs.'''
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        ''' singular instance '''

        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1])
                self.ctest.append(cooked_test) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

        self._score = None ## need to recompute

    def ratio(self, option=None):
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        '''return (bleu, len_ratio) pair'''
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs))
        self._score = None

        return self

    def rescore(self, new_test):
        ''' replace test(s) with new test(s), and returns the new score.'''

        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None ## need to recompute

        return self

    def compatible(self, other):
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):

        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens))/len(reflens)
        elif option == "closest":
            reflen = min((abs(l-testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        n = self.n
        small = 1e-9
        tiny = 1e-15 ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None: ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess','correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        /(float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1./(k+1)))
            ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
            if ratio < 1:
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1/ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1./(k+1)))
        ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1/ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list


In [13]:
# blue.py

#!/usr/bin/env python
#
# File Name : bleu.py
#
# Description : Wrapper for BLEU scorer.
#
# Creation Date : 06-01-2015
# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
# Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>

# Last modified : Wed 22 May 2019 08:10:00 PM EDT
# By Sabarish Sivanath
# To support Python 3

# from .bleu_scorer import BleuScorer


class Bleu:
    def __init__(self, n=4):
        # default compute Blue score up to 4
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res, score_option = 'closest', verbose = 1):
        '''
        Inputs:
            gts - ground truths
            res - predictions
            score_option - {shortest, closest, average}
            verbose - 1 or 0
        Outputs:
            Blue scores
        '''
        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        bleu_scorer = BleuScorer(n=self._n)
        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            #assert(len(ref) >= 1)

            bleu_scorer += (hypo[0], ref)

        score, scores = bleu_scorer.compute_score(option = score_option, verbose =verbose)

        # return (bleu, bleu_info)
        return score, scores

    def method(self):
        return "Bleu"


In [14]:
# meteor.py
import os
import subprocess
import threading

METEOR_JAR = '/content/R2GenCMN/pycocoevalcap/meteor/meteor-1.5.jar'

class Meteor:

    def __init__(self):
        current_directory = '/content/R2GenCMN/pycocoevalcap/meteor'  # replace with the actual directory path
        self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR,
                            '-', '-', '-stdio', '-l', 'en', '-norm']
        self.meteor_p = subprocess.Popen(self.meteor_cmd,
                                        cwd=current_directory,
                                        stdin=subprocess.PIPE,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.PIPE,
                                        universal_newlines=True,
                                        bufsize=1)
        # Used to guarantee thread safety
        self.lock = threading.Lock()

    def compute_score(self, gts, res):
        assert(gts.keys() == res.keys())
        imgIds = gts.keys()
        scores = []

        eval_line = 'EVAL'
        self.lock.acquire()
        for i in imgIds:
            assert(len(res[i]) == 1)
            stat = self._stat(res[i][0], gts[i])
            eval_line += ' ||| {}'.format(stat)

        self.meteor_p.stdin.write('{}\n'.format(eval_line))
        for i in range(0,len(imgIds)):
            scores.append(float(self.meteor_p.stdout.readline().strip()))
        score = float(self.meteor_p.stdout.readline().strip())
        self.lock.release()

        return score, scores

    def method(self):
        return "METEOR"

    def _stat(self, hypothesis_str, reference_list):
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        hypothesis_str = hypothesis_str.replace('|||','').replace('  ',' ')
        score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
        self.meteor_p.stdin.write('{}\n'.format(score_line))
        return self.meteor_p.stdout.readline().strip()

    def _score(self, hypothesis_str, reference_list):
        self.lock.acquire()
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        hypothesis_str = hypothesis_str.replace('|||','').replace('  ',' ')
        score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
        self.meteor_p.stdin.write('{}\n'.format(score_line))
        stats = self.meteor_p.stdout.readline().strip()
        eval_line = 'EVAL ||| {}'.format(stats)
        # EVAL ||| stats
        self.meteor_p.stdin.write('{}\n'.format(eval_line))
        score = float(self.meteor_p.stdout.readline().strip())
        # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice
        # thanks for Andrej for pointing this out
        score = float(self.meteor_p.stdout.readline().strip())
        self.lock.release()
        return score

    def __del__(self):
        self.lock.acquire()
        self.meteor_p.stdin.close()
        self.meteor_p.kill()
        self.meteor_p.wait()
        self.lock.release()


In [15]:
# rough.py

#!/usr/bin/env python
#
# File Name : rouge.py
#
# Description : Computes ROUGE-L metric as described by Lin and Hovey (2004)
#
# Creation Date : 2015-01-07 06:03
# Author : Ramakrishna Vedantam <vrama91@vt.edu>

import numpy as np
import pdb

def my_lcs(string, sub):
    """
    Calculates longest common subsequence for a pair of tokenized strings
    :param string : list of str : tokens from a string split using whitespace
    :param sub : list of str : shorter string, also split using whitespace
    :returns: length (list of int): length of the longest common subsequence between the two strings

    Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
    """
    if(len(string)< len(sub)):
        sub, string = string, sub

    lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]

    for j in range(1,len(sub)+1):
        for i in range(1,len(string)+1):
            if(string[i-1] == sub[j-1]):
                lengths[i][j] = lengths[i-1][j-1] + 1
            else:
                lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])

    return lengths[len(string)][len(sub)]

class Rouge():
    '''
    Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set

    '''
    def __init__(self):
        # vrama91: updated the value below based on discussion with Hovey
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """
        Compute ROUGE-L score given one candidate and references for an image
        :param candidate: str : candidate sentence to be evaluated
        :param refs: list of str : COCO reference sentences for the particular image to be evaluated
        :returns score: int (ROUGE-L score for the candidate evaluated against references)
        """
        assert(len(candidate)==1)
        assert(len(refs)>0)
        prec = []
        rec = []

        # split into tokens
        token_c = candidate[0].split(" ")

        for reference in refs:
            # split into tokens
            token_r = reference.split(" ")
            # compute the longest common subsequence
            lcs = my_lcs(token_r, token_c)
            prec.append(lcs/float(len(token_c)))
            rec.append(lcs/float(len(token_r)))

        prec_max = max(prec)
        rec_max = max(rec)

        if(prec_max!=0 and rec_max !=0):
            score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
        else:
            score = 0.0
        return score

    def compute_score(self, gts, res):
        """
        Computes Rouge-L score given a set of reference and candidate sentences for the dataset
        Invoked by evaluate_captions.py
        :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
        :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
        :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
        """
        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        score = []
        for id in imgIds:
            hypo = res[id]
            ref  = gts[id]

            score.append(self.calc_score(hypo, ref))

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

        average_score = np.mean(np.array(score))
        return average_score, np.array(score)

    def method(self):
        return "Rouge"


In [16]:
#metrics.py

from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score

# from pycocoevalcap.bleu.bleu import Bleu
# from pycocoevalcap.meteor import Meteor
# from pycocoevalcap.rouge import Rouge


def compute_scores(gts, res):
    """
    Performs the MS COCO evaluation using the Python 3 implementation (https://github.com/salaniz/pycocoevalcap)

    :param gts: Dictionary with the image ids and their gold captions,
    :param res: Dictionary with the image ids ant their generated captions
    :print: Evaluation score (the mean of the scores of all the instances) for each measure
    """

    # Set up scorers
    scorers = [
        (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]),
        (Meteor(), "METEOR"),
        (Rouge(), "ROUGE_L")
    ]
    eval_res = {}
    # Compute score for each metric
    for scorer, method in scorers:
        try:
            score, scores = scorer.compute_score(gts, res, verbose=0)
        except TypeError:
            score, scores = scorer.compute_score(gts, res)
        if type(method) == list:
            for sc, m in zip(score, method):
                eval_res[m] = sc
        else:
            eval_res[method] = score
    return eval_res


def compute_mlc(gt, pred, label_set):
    res_mlc = {}
    avg_aucroc = 0
    for i, label in enumerate(label_set):
        res_mlc['AUCROC_' + label] = roc_auc_score(gt[:, i], pred[:, i])
        avg_aucroc += res_mlc['AUCROC_' + label]
    res_mlc['AVG_AUCROC'] = avg_aucroc / len(label_set)

    res_mlc['F1_MACRO'] = f1_score(gt, pred, average="macro")
    res_mlc['F1_MICRO'] = f1_score(gt, pred, average="micro")
    res_mlc['RECALL_MACRO'] = recall_score(gt, pred, average="macro")
    res_mlc['RECALL_MICRO'] = recall_score(gt, pred, average="micro")
    res_mlc['PRECISION_MACRO'] = precision_score(gt, pred, average="macro")
    res_mlc['PRECISION_MICRO'] = precision_score(gt, pred, average="micro")

    return res_mlc


class MetricWrapper(object):
    def __init__(self, label_set):
        self.label_set = label_set

    def __call__(self, gts, res, gts_mlc, res_mlc):
        eval_res = compute_scores(gts, res)
        eval_res_mlc = compute_mlc(gts_mlc, res_mlc, self.label_set)

        eval_res.update(**eval_res_mlc)
        return eval_res


In [17]:
# optimizers

import torch
from torch import optim


def build_optimizer(args, model):
    ve_params = list(map(id, model.visual_extractor.parameters()))
    ed_params = filter(lambda x: id(x) not in ve_params, model.parameters())
    optimizer = getattr(torch.optim, args['optim'])(
        [{'params': model.visual_extractor.parameters(), 'lr': args['lr_ve']},
         {'params': ed_params, 'lr': args['lr_ed']}],
        betas=args['adam_betas'],
        eps=args['adam_eps'],
        weight_decay=args['weight_decay'],
        amsgrad=args['amsgrad']
    )
    return optimizer


def build_lr_scheduler(args, optimizer):
    lr_scheduler = getattr(torch.optim.lr_scheduler, args['lr_scheduler'])(optimizer, args['step_size'], args['gamma'])
    return lr_scheduler


def set_lr(optimizer, lr):
    for group in optimizer.param_groups:
        group['lr'] = lr


def get_lr(optimizer):
    for group in optimizer.param_groups:
        return group['lr']


class NoamOpt(object):
    "Optim wrapper that implements rate."

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
               (self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def __getattr__(self, name):
        return getattr(self.optimizer, name)

    def state_dict(self):
        state_dict = self.optimizer.state_dict()
        state_dict['_step'] = self._step
        return state_dict

    def load_state_dict(self, state_dict):
        if '_step' in state_dict:
            self._step = state_dict['_step']
            del state_dict['_step']
        self.optimizer.load_state_dict(state_dict)


def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
    optim_func = dict(Adam=torch.optim.Adam,
                      AdamW=torch.optim.AdamW)[optim_func]
    return NoamOpt(model.d_model, factor, warmup,
                   optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))


def build_noamopt_optimizer(args, model):
    ve_optimizer = getattr(torch.optim, args['optim'])(
        model.visual_extractor.parameters(),
        lr=0,
        betas=args['adam_betas'],
        eps=args['adam_eps'],
        weight_decay=args['weight_decay'],
        amsgrad=args['amsgrad']
    )
    ed_optimizer = get_std_opt(model.encoder_decoder, optim_func=args['optim'], factor=args['noamopt_factor'],
                               warmup=args['noamopt_warmup'])
    return ve_optimizer, ed_optimizer


class ReduceLROnPlateau(object):
    "Optim wrapper that implements rate."

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001,
                 threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode=mode, factor=factor,
                                                              patience=patience, verbose=verbose, threshold=threshold,
                                                              threshold_mode=threshold_mode, cooldown=cooldown,
                                                              min_lr=min_lr, eps=eps)
        self.optimizer = optimizer
        self.current_lr = get_lr(optimizer)

    def step(self):
        "Update parameters and rate"
        self.optimizer.step()

    def scheduler_step(self, val):
        self.scheduler.step(val)
        self.current_lr = get_lr(self.optimizer)

    def state_dict(self):
        return {'current_lr': self.current_lr,
                'scheduler_state_dict': self.scheduler.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()}

    def load_state_dict(self, state_dict):
        if 'current_lr' not in state_dict:
            # it's normal optimizer
            self.optimizer.load_state_dict(state_dict)
            set_lr(self.optimizer, self.current_lr)  # use the lr fromt the option
        else:
            # it's a schduler
            self.current_lr = state_dict['current_lr']
            self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            # current_lr is actually useless in this case

    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
               (self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def __getattr__(self, name):
        return getattr(self.optimizer, name)


def build_plateau_optimizer(args, model):
    ve_optimizer = getattr(torch.optim, args['optim'])(
        model.visual_extractor.parameters(),
        lr=args['lr_ve'],
        betas=args['adam_betas'],
        eps=args['adam_eps'],
        weight_decay=args['weight_decay'],
        amsgrad=args['amsgrad']
    )
    ve_optimizer = ReduceLROnPlateau(ve_optimizer,
                                     factor=args['reduce_on_plateau_factor'],
                                     patience=args['reduce_on_plateau_patience'])
    ed_optimizer = getattr(torch.optim, args['optim'])(
        model.encoder_decoder.parameters(),
        lr=args['lr_ed'],
        betas=args['adam_betas'],
        eps=args['adam_eps'],
        weight_decay=args['weight_decay'],
        amsgrad=args['amsgrad']
    )
    ed_optimizer = ReduceLROnPlateau(ed_optimizer,
                                     factor=args['reduce_on_plateau_factor'],
                                     patience=args['reduce_on_plateau_patience'])

    return ve_optimizer, ed_optimizer


In [18]:
# tokenizer.py

import json
import re
from collections import Counter


class Tokenizer(object):
    def __init__(self, args):
        self.ann_path = args['ann_path']
        self.threshold = args['threshold']
        self.dataset_name = args['dataset_name']
        if self.dataset_name == 'iu_xray':
            self.clean_report = self.clean_report_iu_xray
        else:
            self.clean_report = self.clean_report_mimic_cxr
        self.ann = json.loads(open(self.ann_path, 'r').read())
        self.token2idx, self.idx2token = self.create_vocabulary()

    def create_vocabulary(self):
        total_tokens = []

        for example in self.ann['train']:
            tokens = self.clean_report(example['report']).split()
            for token in tokens:
                total_tokens.append(token)

        counter = Counter(total_tokens)
        vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
        vocab.sort()
        token2idx, idx2token = {}, {}
        for idx, token in enumerate(vocab):
            token2idx[token] = idx + 1
            idx2token[idx + 1] = token
        return token2idx, idx2token

    def clean_report_iu_xray(self, report):
        report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
            .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
            .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '').
                                        replace('\\', '').replace("'", '').strip().lower())
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
        report = ' . '.join(tokens) + ' .'
        return report

    def clean_report_mimic_cxr(self, report):
        report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
            .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
            .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
            .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
            .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
            .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
            .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                        .replace('\\', '').replace("'", '').strip().lower())
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
        report = ' . '.join(tokens) + ' .'
        return report

    def get_token_by_id(self, id):
        return self.idx2token[id]

    def get_id_by_token(self, token):
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        return len(self.token2idx)

    def __call__(self, report):
        tokens = self.clean_report(report).split()
        ids = []
        for token in tokens:
            ids.append(self.get_id_by_token(token))
        ids = [0] + ids + [0]
        return ids

    def decode(self, ids):
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt

    def decode_batch(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode(ids))
        return out


In [19]:
# trainer.py

import logging
import os
from abc import abstractmethod

import torch
from numpy import inf


class BaseTrainer(object):
    def __init__(self, model, criterion, metric_ftns, optimizer, args, lr_scheduler):
        self.args = args

        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self.epochs = self.args['epochs']
        self.save_period = self.args['save_period']

        self.mnt_mode = args['monitor_mode']
        self.mnt_metric = 'val_' + args['monitor_metric']
        self.mnt_metric_test = 'test_' + args['monitor_metric']
        assert self.mnt_mode in ['min', 'max']

        self.mnt_best = inf if self.mnt_mode == 'min' else -inf
        self.early_stop = getattr(self.args, 'early_stop', inf)

        self.start_epoch = 1
        self.checkpoint_dir = args['save_dir']

        self.best_recorder = {'val': {self.mnt_metric: self.mnt_best},
                              'test': {self.mnt_metric_test: self.mnt_best}}

        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        if args['resume'] is not None:
            self._resume_checkpoint(args['resume'])

    @abstractmethod
    def _train_epoch(self, epoch):
        raise NotImplementedError

    def train(self):
        not_improved_count = 0

        for epoch in range(self.start_epoch, self.epochs + 1):

            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)
            self._record_best(log)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('\t{:15s}: {}'.format(str(key), value))
                print('\t{:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. " "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. " "Training stops.".format(
                        self.early_stop))
                    print("Validation performance didn\'t improve for {} epochs. " "Training stops.".format(
                        self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _record_best(self, log):
        improved_val = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.best_recorder['val'][
            self.mnt_metric]) or \
                       (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.best_recorder['val'][self.mnt_metric])
        if improved_val:
            self.best_recorder['val'].update(log)

        improved_test = (self.mnt_mode == 'min' and log[self.mnt_metric_test] <= self.best_recorder['test'][
            self.mnt_metric_test]) or \
                        (self.mnt_mode == 'max' and log[self.mnt_metric_test] >= self.best_recorder['test'][
                            self.mnt_metric_test])
        if improved_test:
            self.best_recorder['test'].update(log)

    def _print_best(self):
        self.logger.info('Best results (w.r.t {}) in validation set:'.format(self.args['monitor_metric']))
        for key, value in self.best_recorder['val'].items():
            self.logger.info('\t{:15s}: {}'.format(str(key), value))

        self.logger.info('Best results (w.r.t {}) in test set:'.format(self.args['monitor_metric']))
        for key, value in self.best_recorder['test'].items():
            self.logger.info('\t{:15s}: {}'.format(str(key), value))

    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format(
                    n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best
        }
        filename = os.path.join(self.checkpoint_dir, 'current_checkpoint.pth')
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        print("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")
            print("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        print("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
        print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))


class Trainer(BaseTrainer):
    def __init__(self, model, criterion, metric_ftns, optimizer, args, lr_scheduler, train_dataloader,
                 val_dataloader, test_dataloader):
        super(Trainer, self).__init__(model, criterion, metric_ftns, optimizer, args, lr_scheduler)
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader


    def _train_epoch(self, epoch):


        self.logger.info('[{}/{}] Start to train in the training set.'.format(epoch, self.epochs))
        print('[{}/{}] Start to train in the training set.'.format(epoch, self.epochs))
        train_loss = 0
        self.model.train()
        for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.train_dataloader):

            images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(self.device), \
                                                 reports_masks.to(self.device)
            output = self.model(images, reports_ids, mode='train')
            loss = self.criterion(output, reports_ids, reports_masks)
            train_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if batch_idx % self.args['log_period'] == 0:
                self.logger.info('[{}/{}] Step: {}/{}, Training Loss: {:.5f}.'
                                 .format(epoch, self.epochs, batch_idx, len(self.train_dataloader),
                                         train_loss / (batch_idx + 1)))
                print('[{}/{}] Step: {}/{}, Training Loss: {:.5f}.'
                                 .format(epoch, self.epochs, batch_idx, len(self.train_dataloader),
                                         train_loss / (batch_idx + 1)))

        log = {'train_loss': train_loss / len(self.train_dataloader)}

        self.logger.info('[{}/{}] Start to evaluate in the validation set.'.format(epoch, self.epochs))
        self.model.eval()
        with torch.no_grad():
            val_gts, val_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.val_dataloader):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)

                output, _ = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                val_res.extend(reports)
                val_gts.extend(ground_truths)

            val_met = self.metric_ftns({i: [gt] for i, gt in enumerate(val_gts)},
                                       {i: [re] for i, re in enumerate(val_res)})
            log.update(**{'val_' + k: v for k, v in val_met.items()})

        self.logger.info('[{}/{}] Start to evaluate in the test set.'.format(epoch, self.epochs))
        self.model.eval()
        with torch.no_grad():
            test_gts, test_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.test_dataloader):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                test_res.extend(reports)
                test_gts.extend(ground_truths)

            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [re] for i, re in enumerate(test_res)})
            log.update(**{'test_' + k: v for k, v in test_met.items()})

        self.lr_scheduler.step()

        return log


In [20]:
torch.cuda.empty_cache()

In [None]:
# For PyTorch
import torchvision.models as models
print(dir(models))

['AlexNet', 'AlexNet_Weights', 'ConvNeXt', 'ConvNeXt_Base_Weights', 'ConvNeXt_Large_Weights', 'ConvNeXt_Small_Weights', 'ConvNeXt_Tiny_Weights', 'DenseNet', 'DenseNet121_Weights', 'DenseNet161_Weights', 'DenseNet169_Weights', 'DenseNet201_Weights', 'EfficientNet', 'EfficientNet_B0_Weights', 'EfficientNet_B1_Weights', 'EfficientNet_B2_Weights', 'EfficientNet_B3_Weights', 'EfficientNet_B4_Weights', 'EfficientNet_B5_Weights', 'EfficientNet_B6_Weights', 'EfficientNet_B7_Weights', 'EfficientNet_V2_L_Weights', 'EfficientNet_V2_M_Weights', 'EfficientNet_V2_S_Weights', 'GoogLeNet', 'GoogLeNetOutputs', 'GoogLeNet_Weights', 'Inception3', 'InceptionOutputs', 'Inception_V3_Weights', 'MNASNet', 'MNASNet0_5_Weights', 'MNASNet0_75_Weights', 'MNASNet1_0_Weights', 'MNASNet1_3_Weights', 'MaxVit', 'MaxVit_T_Weights', 'MobileNetV2', 'MobileNetV3', 'MobileNet_V2_Weights', 'MobileNet_V3_Large_Weights', 'MobileNet_V3_Small_Weights', 'RegNet', 'RegNet_X_16GF_Weights', 'RegNet_X_1_6GF_Weights', 'RegNet_X_32GF_

In [22]:
!pip install timm

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.16


In [None]:
import timm
print(timm.list_models())


['bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_large_patch16_224', 'botnet26t_256', 'botnet50ts_256', 'caformer_b36', 'caformer_m36', 'caformer_s18', 'caformer_s36', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_medium', 'coat_lite_medium_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_small', 'coat_tiny', 'coatnet_0_224', 'coatnet_0_rw_224', 'coatnet_1_224', 'coatnet_1_rw_224', 'coatnet_2_224', 'coatnet_2_rw_224', 'coatnet_3_224', 'coatnet_3_rw_224', 'coatnet_4_224', 'coatnet_5_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_cc_224', 'coatnet_nano_rw_224', 'coatnet_pico_rw_224', 'coatnet_rmlp_0_rw_224', 'coatnet_rmlp_1_rw2_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_2_rw_224', 'coatnet_rmlp_2

In [None]:
#twins_pcpvt_base
# twins_pcpvt_large
# twins_pcpvt_small
# vgg11
#vit_tiny_r_s16_p8_384
# vit_tiny_r_s16_p8_224
# vit_small_r26_s32_224
# vit_large_r50_s32_384
# vit_large_r50_s32_224
# resnest200e
# vgg19_bn
# seresnext101_64x4d
# volo_d1_224

# densenet121




In [None]:
from timm import create_model
model = create_model('wide_resnet101_2', pretrained=True)
model

In [23]:
import torch
import torch.nn as nn
from timm.models import create_model

class VisualExtractor(nn.Module):
    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.visual_extractor = args['visual_extractor']
        self.pretrained = args['visual_extractor_pretrained']

        # Create a new first layer to handle (3, 64) input channels
        new_first_layer = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Load the pretrained Swin Transformer model
        model = create_model(self.visual_extractor, pretrained=self.pretrained, first_layer=new_first_layer)

        # Replace the initial convolutional layer
        model.patch_embed.proj = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

        # Determine the input size for layer normalization based on the actual size of the output
        _, last_size, _, _ = self.model(torch.randn(1, 3, 224, 224)).shape
        self.norm = nn.LayerNorm([last_size, last_size, 64])

    def forward(self, images):
        patch_feats = self.model(images)
        print("Below the model(images) in Visual Extractor")

        B, C, H, W = patch_feats.shape
        patch_feats = patch_feats.permute(0, 2, 3, 1).contiguous().view(B * H * W, C)

        # Apply layer normalization
        patch_feats = self.norm(patch_feats.view(B, H, W, C).permute(0, 3, 1, 2)).view(B, C, H, W)

        avg_feats = self.avg_fnt(patch_feats).squeeze().reshape(-1, patch_feats.size(1))

        return patch_feats, avg_feats


In [25]:
!pip install torchviz


Collecting torchviz
  Downloading torchviz-0.0.2.tar.gz (4.9 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torchviz
  Building wheel for torchviz (setup.py) ... [?25l[?25hdone
  Created wheel for torchviz: filename=torchviz-0.0.2-py3-none-any.whl size=4131 sha256=8cdd5c99d8c63e484c6aeb60bdda6974de14f9ff16a0722f473535bbc0c0cae9
  Stored in directory: /root/.cache/pip/wheels/4c/97/88/a02973217949e0db0c9f4346d154085f4725f99c4f15a87094
Successfully built torchviz
Installing collected packages: torchviz
Successfully installed torchviz-0.0.2


In [27]:
import torch
import torch.nn as nn
from torchviz import make_dot

# Define the VisualExtractor class
class VisualExtractor(nn.Module):
    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.visual_extractor = args['visual_extractor']
        self.pretrained = args['visual_extractor_pretrained']
        model = create_model(self.visual_extractor, pretrained=self.pretrained)
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

    def forward(self, images):
        patch_feats = self.model(images)
        avg_feats = self.avg_fnt(patch_feats).squeeze().reshape(-1, patch_feats.size(1))
        batch_size, feat_size, _, _ = patch_feats.shape
        patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
        return patch_feats, avg_feats


# # Save the graph to a file
# graph.render("visual_extractor_graph")


In [24]:
# visual_extractor.py


import torch
import torch.nn as nn
# import torchvision.models as models
from timm import create_model


class VisualExtractor(nn.Module):
    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.visual_extractor = args['visual_extractor']
        self.pretrained = args['visual_extractor_pretrained']
        model = create_model(self.visual_extractor, pretrained=self.pretrained)
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

    def forward(self, images):
        patch_feats = self.model(images)
        print("in visual extractor below the model(image)")
        avg_feats = self.avg_fnt(patch_feats).squeeze().reshape(-1, patch_feats.size(1))
        batch_size, feat_size, _, _ = patch_feats.shape
        patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
        return patch_feats, avg_feats




In [None]:
torch.cuda.empty_cache()

In [None]:

#train.py

import argparse
import torchvision.models as models

import numpy as np
import torch
import os



# from models.models import BaseCMNModel
# from modules.dataloaders import R2DataLoader
# from modules.loss import compute_loss
# from modules.metrics import compute_scores
# from modules.optimizers import build_optimizer, build_lr_scheduler
# from modules.tokenizers import Tokenizer
# from modules.trainer import Trainer
# resnet152
# swinv2_cr_small_ns_224
# swin_s3_small_224
# tiny_vit_5m_224
# vit_base_patch16_clip_224
# densenet121
# res2net101_26w_4s
# res2net101d
# tf_efficientnet_lite4
# tf_mixnet_l


def main():
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    args = {
        'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'num_workers': 2,
        'threshold':3,
        'batch_size': 10,
        'visual_extractor': 'resnet101',                 # vgg11 , vgg13
        'visual_extractor_pretrained': True,           # working models (resnetrs101,)
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
       'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 6,
        'save_dir': '/content/drive/MyDrive/111',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,

    }

    torch.manual_seed(args['seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args['seed'])

    # create tokenizer
    tokenizer = Tokenizer(args)

    # create data loader
    train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
    val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # build model architecture
    model = BaseCMNModel(args, tokenizer)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    # build optimizer, learning rate scheduler
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # build trainer and start to train
    trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler, train_dataloader, val_dataloader, test_dataloader)
    trainer.train()



if __name__ == '__main__':
    main()


model.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

bellow the visual extractor
Visual output size: 2048
linear_input_size: 2048
bellow the linear layer
in basecmnmodel bellow the encode_decoder
in basecmnmodel iu xray forward pass
[1/6] Start to train in the training set.
ending the dataloader
ending the dataloader
In forward_iu_xray in BaseCMNModel
in visual extractor below the model(image)
Below the avg_feats in BaseCMNModel
Shape of avg_feats before flattening: torch.Size([20, 2048])
Shape of avg_feats after flattening: torch.Size([20, 2048])
Shape of avg_feats after flattening: torch.Size([20, 2048])
Below the linear_output in forward_iu_xray
in train
ending the dataloader


TypeError: ignored

In [None]:
############# here we test the model ###############3

In [28]:
#tester.py
import logging
import os
from abc import abstractmethod

import cv2
import numpy as np
import pandas as pd
import spacy
import torch
from tqdm import tqdm




class BaseTester(object):
    def __init__(self, model, criterion, metric_ftns, args):
        self.args = args

        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns

        self.epochs = self.args['epochs']
        self.save_dir = self.args['save_dir']
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self._load_checkpoint(args['load'])

    @abstractmethod
    def test(self):
        raise NotImplementedError

    @abstractmethod
    def plot(self):
        raise NotImplementedError

    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format(
                    n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _load_checkpoint(self, load_path):
        load_path = str(load_path)
        self.logger.info("Loading checkpoint: {} ...".format(load_path))
        checkpoint = torch.load(load_path)
        self.model.load_state_dict(checkpoint['state_dict'])


class Tester(BaseTester):
    def __init__(self, model, criterion, metric_ftns, args, test_dataloader):
        super(Tester, self).__init__(model, criterion, metric_ftns, args)
        self.test_dataloader = test_dataloader

    def test(self):
        self.logger.info('Start to evaluate in the test set.')
        self.model.eval()
        log = dict()
        with torch.no_grad():
            test_gts, test_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                test_res.extend(reports)
                test_gts.extend(ground_truths)

            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [re] for i, re in enumerate(test_res)})
            log.update(**{'test_' + k: v for k, v in test_met.items()})
            print(log)

            test_res, test_gts = pd.DataFrame(test_res), pd.DataFrame(test_gts)
            test_res.to_csv(os.path.join(self.save_dir, "res.csv"), index=False, header=False)
            test_gts.to_csv(os.path.join(self.save_dir, "gts.csv"), index=False, header=False)

        return log

    def plot(self):
        assert self.args['batch_size'] == 1 and self.args['beam_size'] == 1
        self.logger.info('Start to plot attention weights in the test set.')
        os.makedirs(os.path.join(self.save_dir, "attentions"), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, "attentions_entities"), exist_ok=True)
        ner = spacy.load("en_core_sci_sm")
        mean = torch.tensor((0.485, 0.456, 0.406))
        std = torch.tensor((0.229, 0.224, 0.225))
        mean = mean[:, None, None]
        std = std[:, None, None]

        self.model.eval()
        with torch.no_grad():
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().cpu().numpy()
                report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0].split()

                char2word = [idx for word_idx, word in enumerate(report) for idx in [word_idx] * (len(word) + 1)][:-1]

                attention_weights = self.model.encoder_decoder.attention_weights[:-1]
                assert len(attention_weights) == len(report)
                for word_idx, (attns, word) in enumerate(zip(attention_weights, report)):
                    for layer_idx, attn in enumerate(attns):
                        os.makedirs(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)

                        heatmap = generate_heatmap(image, attn.mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx), "{:04d}_{}.png".format(word_idx, word)),
                                    heatmap)

                for ne_idx, ne in enumerate(ner(" ".join(report)).ents):
                    for layer_idx in range(len(attention_weights[0])):
                        os.makedirs(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)
                        attn = [attns[layer_idx] for attns in
                                attention_weights[char2word[ne.start_char]:char2word[ne.end_char] + 1]]
                        attn = np.concatenate(attn, axis=2)
                        heatmap = generate_heatmap(image, attn.mean(1).mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx), "{:04d}_{}.png".format(ne_idx, ne)),
                                    heatmap)


Here we test the model.

In [29]:
#tester.py
import logging
import os
from abc import abstractmethod

import cv2
import numpy as np
import pandas as pd
import spacy
import torch
from tqdm import tqdm




class BaseTester(object):
    def __init__(self, model, criterion, metric_ftns, args):
        self.args = args

        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns

        self.epochs = self.args['epochs']
        self.save_dir = self.args['save_dir']
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self._load_checkpoint(args['load'])

    @abstractmethod
    def test(self):
        raise NotImplementedError

    @abstractmethod
    def plot(self):
        raise NotImplementedError

    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format(
                    n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _load_checkpoint(self, load_path):
        load_path = str(load_path)
        self.logger.info("Loading checkpoint: {} ...".format(load_path))
        checkpoint = torch.load(load_path)
        self.model.load_state_dict(checkpoint['state_dict'])


class Tester(BaseTester):
    def __init__(self, model, criterion, metric_ftns, args, test_dataloader):
        super(Tester, self).__init__(model, criterion, metric_ftns, args)
        self.test_dataloader = test_dataloader

    def test(self):
        self.logger.info('Start to evaluate in the test set.')
        self.model.eval()
        log = dict()
        with torch.no_grad():
            test_gts, test_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                test_res.extend(reports)
                test_gts.extend(ground_truths)

            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [re] for i, re in enumerate(test_res)})
            log.update(**{'test_' + k: v for k, v in test_met.items()})
            print(log)

            test_res, test_gts = pd.DataFrame(test_res), pd.DataFrame(test_gts)
            test_res.to_csv(os.path.join(self.save_dir, "res.csv"), index=False, header=False)
            test_gts.to_csv(os.path.join(self.save_dir, "gts.csv"), index=False, header=False)

        return log

    def plot(self):
        assert self.args['batch_size'] == 1 and self.args['beam_size'] == 1
        self.logger.info('Start to plot attention weights in the test set.')
        os.makedirs(os.path.join(self.save_dir, "attentions"), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, "attentions_entities"), exist_ok=True)
        ner = spacy.load("en_core_sci_sm")
        mean = torch.tensor((0.485, 0.456, 0.406))
        std = torch.tensor((0.229, 0.224, 0.225))
        mean = mean[:, None, None]
        std = std[:, None, None]

        self.model.eval()
        with torch.no_grad():
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().cpu().numpy()
                report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0].split()

                char2word = [idx for word_idx, word in enumerate(report) for idx in [word_idx] * (len(word) + 1)][:-1]

                attention_weights = self.model.encoder_decoder.attention_weights[:-1]
                assert len(attention_weights) == len(report)
                for word_idx, (attns, word) in enumerate(zip(attention_weights, report)):
                    for layer_idx, attn in enumerate(attns):
                        os.makedirs(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)

                        heatmap = generate_heatmap(image, attn.mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx), "{:04d}_{}.png".format(word_idx, word)),
                                    heatmap)

                for ne_idx, ne in enumerate(ner(" ".join(report)).ents):
                    for layer_idx in range(len(attention_weights[0])):
                        os.makedirs(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)
                        attn = [attns[layer_idx] for attns in
                                attention_weights[char2word[ne.start_char]:char2word[ne.end_char] + 1]]
                        attn = np.concatenate(attn, axis=2)
                        heatmap = generate_heatmap(image, attn.mean(1).mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx), "{:04d}_{}.png".format(ne_idx, ne)),
                                    heatmap)


##the below one is working well for saving the report files

In [31]:
import torch
import torch.nn as nn
from torchviz import make_dot

# Define the VisualExtractor class
class VisualExtractor(nn.Module):
    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.visual_extractor = args['visual_extractor']
        self.pretrained = args['visual_extractor_pretrained']
        model = create_model(self.visual_extractor, pretrained=self.pretrained)
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

    def forward(self, images):
        patch_feats = self.model(images)
        avg_feats = self.avg_fnt(patch_feats).squeeze().reshape(-1, patch_feats.size(1))
        batch_size, feat_size, _, _ = patch_feats.shape
        patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
        return patch_feats, avg_feats


# # Save the graph to a file
# graph.render("visual_extractor_graph")


In [41]:
import torch
import torch.nn as nn
from torchviz import make_dot
from timm import create_model  # Add import statement for create_model

# Define the VisualExtractor class
class VisualExtractor(nn.Module):
    def __init__(self, args):
        super(VisualExtractor, self).__init__()
        self.visual_extractor = args['visual_extractor']
        self.pretrained = args['visual_extractor_pretrained']
        model = create_model(self.visual_extractor, pretrained=self.pretrained)
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.avg_fnt = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

    def forward(self, images):
        patch_feats = self.model(images)
        avg_feats = self.avg_fnt(patch_feats).squeeze().reshape(-1, patch_feats.size(1))
        batch_size, feat_size, _, _ = patch_feats.shape
        patch_feats = patch_feats.reshape(batch_size, feat_size, -1).permute(0, 2, 1)
        return patch_feats, avg_feats


def main():
    args = {
        'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'threshold': 3,
        'num_workers': 2,
        'batch_size': 16,
        'visual_extractor': 'resnet101',
        'visual_extractor_pretrained': True,
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
        'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 100,
        'save_dir': 'results/iu_xray',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,
        # 'load': "/content/results/iu_xray/model_best.pth",
        'load': "/content/drive/MyDrive/resnest200e_save_checkpoints/model_best.pth",

    }

     # Instantiate the VisualExtractor model
    model = VisualExtractor(args)

     # Create a random input tensor
    images = torch.randn(1, 3, 224, 224)  # Batch size 1, RGB channels, image size 224x224

    # Pass the input through the model
    patch_feats, avg_feats = model(images)

    # Print the length of the output layer
    print("Length of the output layer:", patch_feats.size(1))

    # Print the shapes of the output tensors
    print("Shape of patch_feats:", patch_feats.shape)
    print("Shape of avg_feats:", avg_feats.shape)

    #################   input and output tensors ################33

    # # Create a random input tensor
    # images = torch.randn(1, 3, 224, 224)  # Batch size 1, RGB channels, image size 224x224

    # print("Input tensor shape:", images.shape)
    # print("Input tensor values:", images)

    # # Pass the input tensor through the model
    # patch_feats, avg_feats = model(images)

    # print("Output patch features shape:", patch_feats.shape)
    # print("Output patch features:", patch_feats)

    # print("Output average features shape:", avg_feats.shape)
    # print("Output average features:", avg_feats)

############ Generate Graph of the visual extractor #################
    # # Instantiate the VisualExtractor model
    # model = VisualExtractor(args)

    # # Create a random input tensor
    # images = torch.randn(1, 3, 224, 224)  # Batch size 1, RGB channels, image size 224x224

    # # Create a graph of the model
    # output = model(images)
    # graph = make_dot(output)  # Corrected to pass the entire model's output

    # # Save the graph to a file
    # graph.render("visual_extractor_graph")

if __name__ == '__main__':
    main()


Length of the output layer: 49
Shape of patch_feats: torch.Size([1, 49, 2048])
Shape of avg_feats: torch.Size([1, 2048])


In [None]:
import argparse
import numpy as np
import torch

# from models.models import BaseCMNModel
# from modules.dataloaders import R2DataLoader
# from modules.loss import compute_loss
# from modules.metrics import compute_scores
# from modules.tokenizers import Tokenizer
# from modules.tester import Tester




def main():
    args = {
         'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'threshold': 3,
        'num_workers': 2,
        'batch_size': 16,
        'visual_extractor': 'resnet101',
        'visual_extractor_pretrained': True,
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
        'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 100,
        'save_dir': 'results/iu_xray',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,
        # 'load': "/content/results/iu_xray/model_best.pth",
        'load': "/content/drive/MyDrive/resnest200e_save_checkpoints/model_best.pth",

    }

     # Create tokenizer
    tokenizer = Tokenizer(args)

    # Build model architecture
    model = BaseCMNModel(args, tokenizer)

    return model


if __name__ == '__main__':


    args = {
         'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'threshold': 3,
        'num_workers': 2,
        'batch_size': 16,
        'visual_extractor': 'resnet101',
        'visual_extractor_pretrained': True,
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
        'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 100,
        'save_dir': 'results/iu_xray',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,
        # 'load': "/content/results/iu_xray/model_best.pth",
        'load': "/content/drive/MyDrive/resnest200e_save_checkpoints/model_best.pth",

    }





    tokenizer = Tokenizer(args)
    main()
    model = main()





    # Check if the model is not None
    if model is not None:
        # Move the model to GPU if it's not already there
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)


        # Image paths
        # image_path_1 = '/content/drive/MyDrive/iu_xray/images/CXR1000_IM-0003/0.png'
        # image_path_2 = '/content/drive/MyDrive/iu_xray/images/CXR1001_IM-0004/1.png'

        # Preprocess the images
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        image_1 = Image.open(image_path_1).convert('RGB')
        # image_2 = Image.open(image_path_2).convert('RGB')

        image_1 = transform(image_1)
        # image_2 = transform(image_2)

        # Stack the images along a new dimension
        image = torch.stack([image_1, image_1], dim=0)

        # Add batch dimension
        image = image.unsqueeze(0)

        # Move the input image to the same device as the model
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        image = image.to(device)



        # Generate the report
        with torch.no_grad():
            output, _ = model(image, mode='sample')

        generated_report = tokenizer.decode_batch(output.cpu().detach().numpy())[0]

        # Print the generated report
        print(f"Generated Report for Image 1: {generated_report}")

    else:
        print("Error: Model is None. Check the main function.")



Generated Report for Image 1: tortuosity acromioclavicular top acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular granulomas acromioclavicular subtle shadows wedging acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular dual-lumen shift cavoatrial acromioclavicular acromioclavicular posteriorly acromioclavicular acromioclavicular concerning available shift been hyperexpanded osteophytes hyperexpanded underlying shift coarsened underlying shift acromioclavicular acromioclavicular acromioclavicular acromioclavicular acromioclavicular base acromioclavicular base acromioclavicular available amount acromioclavicular in shift


In [None]:
import os
import torch
from PIL import Image
from torchvision import transforms

def preprocess_images(image_paths, transform=None):
    images = []

    for image_path in image_paths:
        # Load the image
        image = Image.open(image_path).convert('RGB')

        # Apply transformations if provided
        if transform is not None:
            image = transform(image)

        images.append(image)

    # Ensure all images have three channels
    images = [img if img.shape[0] == 3 else img.repeat(3, 1, 1) for img in images]

    # Stack the images along the 0th dimension to create a tensor
    images = torch.stack(images, 0)

    # Placeholder values for report-related information (as they are not used during testing)
    image_ids = None
    report_ids = None
    report_masks = None
    seq_lengths = None

    return image_ids, images, report_ids, report_masks, seq_lengths


# Define the generate_reports_for_images function
def generate_reports_for_images(model, image_paths, tokenizer, device='cuda'):
    # Load and preprocess the images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    image_ids, images, report_ids, report_masks, seq_lengths = preprocess_images(image_paths, transform)
    images = images.to(device)  # Move to the specified device

    print(f"Input images size: {images.size()}")

    # Generate reports
    model.eval()
    with torch.no_grad():
        outputs, _ = model(images, mode='sample')
        generated_reports = tokenizer.decode_batch(outputs.cpu().numpy())

    return generated_reports



In [None]:

import argparse
import numpy as np
import torch

# from models.models import BaseCMNModel
# from modules.dataloaders import R2DataLoader
# from modules.loss import compute_loss
# from modules.metrics import compute_scores
# from modules.tokenizers import Tokenizer
# from modules.tester import Tester




def main():
    args = {
         'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'threshold': 3,
        'num_workers': 2,
        'batch_size': 16,
        'visual_extractor': 'resnet101',
        'visual_extractor_pretrained': True,
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
        'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 100,
        'save_dir': 'results/iu_xray',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,
        # 'load': "/content/results/iu_xray/model_best.pth",
        'load': "/content/drive/MyDrive/CMN_saved_checkpoints/current_checkpoint.pth",

    }

     # Create tokenizer
    tokenizer = Tokenizer(args)

    # Build model architecture
    model = BaseCMNModel(args, tokenizer)

    model = model.to('cuda')  # Move model to GPU

    # Load the trained model checkpoint
    checkpoint = torch.load(args['load'], map_location='cuda' if torch.cuda.is_available() else 'cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Specify the paths of the two images
    image_paths = [
        '/content/drive/MyDrive/iu_xray/images/CXR1000_IM-0003/0.png',
        '/content/drive/MyDrive/iu_xray/images/CXR1001_IM-0004/1.png'
    ]

    # Generate reports for the two images
    generated_reports = generate_reports_for_images(model, image_paths, tokenizer)

    # Print the generated reports
    for i, report in enumerate(generated_reports):
        print(f"Generated Report for Image {i + 1}: {report}")



if __name__ == '__main__':
    main()


model.safetensors:   0%|          | 0.00/179M [00:00<?, ?B/s]

RuntimeError: ignored

here we generate report

In [None]:
import argparse
import numpy as np
import torch

# from models.models import BaseCMNModel
# from modules.dataloaders import R2DataLoader
# from modules.loss import compute_loss
# from modules.metrics import compute_scores
# from modules.tokenizers import Tokenizer
# from modules.tester import Tester




def main():
    args = {
         'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'threshold': 3,
        'num_workers': 2,
        'batch_size': 16,
        'visual_extractor': 'resnet101',
        'visual_extractor_pretrained': True,
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
        'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 100,
        'save_dir': 'results/iu_xray',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,
        # 'load': "/content/results/iu_xray/model_best.pth",
        'load': "/content/drive/MyDrive/CMN_saved_checkpoints/current_checkpoint.pth",

    }

     # Create tokenizer
    tokenizer = Tokenizer(args)

    # Build model architecture
    model = BaseCMNModel(args, tokenizer)

    return model


if __name__ == '__main__':


    args = {
         'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'threshold': 3,
        'num_workers': 2,
        'batch_size': 16,
        'visual_extractor': 'resnet101',
        'visual_extractor_pretrained': True,
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
        'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 100,
        'save_dir': 'results/iu_xray',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,
        # 'load': "/content/results/iu_xray/model_best.pth",
        'load': "/content/drive/MyDrive/CMN_saved_checkpoints/current_checkpoint.pth",

    }





    tokenizer = Tokenizer(args)
    main()
    model = main()





    # Check if the model is not None
    if model is not None:
        # Move the model to GPU if it's not already there
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

        # Image paths
        image_path_1 = '/content/drive/MyDrive/iu_xray/images/CXR1000_IM-0003/0.png'
        image_path_2 = '/content/drive/MyDrive/iu_xray/images/CXR1001_IM-0004/1.png'

        # ... (rest of your code)

        # Move the input image to the same device as the model
        image = image.to(device)

        # Generate the report
        with torch.no_grad():
            output, _ = model(image, mode='sample')

        generated_report = tokenizer.decode_batch(output.cpu().detach().numpy())[0]

        # Print the generated report
        print(f"Generated Report for Image 1: {generated_report}")

    else:
        print("Error: Model is None. Check the main function.")



Generated Report for Image 1: clips probable probable probable probable infrahilar carina probable infrahilar granulomatous infrahilar carina probable granulomatous probable probable infrahilar granulomatous probable granulomatous probable probable infrahilar granulomatous infrahilar infrahilar infrahilar granulomatous infrahilar 6th projected 6th material 6th cardiopulmonary probable infrahilar right opacity lung demonstrates probable infrahilar infrahilar granulomatous infrahilar granulomatous infrahilar infrahilar infrahilar enlarged lung cardiopulmonary probable suspicious granulomatous infrahilar granulomatous infrahilar carina


here we print report corresponding an image

In [None]:
import logging
import os
from abc import abstractmethod

import cv2
import numpy as np
import pandas as pd
import spacy
import torch
from tqdm import tqdm

# from modules.utils import generate_heatmap


class BaseTester(object):
    def __init__(self, model, criterion, metric_ftns, args):
        self.args = args

        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns

        self.epochs = self.args["epochs"]
        self.save_dir = self.args['save_dir']
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self._load_checkpoint(args['load'])

    @abstractmethod
    def test(self):
        raise NotImplementedError

    @abstractmethod
    def plot(self):
        raise NotImplementedError

    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format(
                    n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _load_checkpoint(self, load_path):
        load_path = str(load_path)
        self.logger.info("Loading checkpoint: {} ...".format(load_path))
        checkpoint = torch.load(load_path)
        self.model.load_state_dict(checkpoint['state_dict'])


class Tester(BaseTester):
    def __init__(self, model, criterion, metric_ftns, args, test_dataloader):
        super(Tester, self).__init__(model, criterion, metric_ftns, args)
        self.test_dataloader = test_dataloader

    def test(self):
        self.logger.info('Start to evaluate in the test set.')
        self.model.eval()
        log = dict()
        with torch.no_grad():
            test_gts, test_res = [], []  # here we are storing the reports

            print("Generating Reports:")
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
               # Assuming you want to process the first image in the batch
               single_image = images[4]  # Change this line to select a different image if needed
               single_image = single_image.unsqueeze(0)  # Add batch dimension

            # Print image ID
            image_id = images_id[4]
            print(f"Image ID: {image_id}")

            print(f"Image batch size: {images.size()}")
            print(f"Reports IDs batch size: {reports_ids.size()}")
            print(f"Reports Masks batch size: {reports_masks.size()}")

            # Process the image and generate the report
            images, reports_ids, reports_masks = single_image.to(self.device), reports_ids.to(self.device), reports_masks.to(self.device)
            output, _ = self.model(images, mode='sample')
            generated_report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0]

            # Print the generated report
            print(f"Generated Report: {generated_report}\n")

            # Additional print statement to show progress
            if (batch_idx + 1) % 10 == 0:
               print(f"Processed {batch_idx + 1} batches")

            print("Reports Generated Successfully!")

            # ... (rest of the method remains unchanged)

            return log


    def plot(self):
        assert self.args['batch_size'] == 1 and self.args['beam_size'] == 1
        self.logger.info('Start to plot attention weights in the test set.')
        os.makedirs(os.path.join(self.save_dir, "attentions"), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, "attentions_entities"), exist_ok=True)
        ner = spacy.load("en_core_sci_sm")
        mean = torch.tensor((0.485, 0.456, 0.406))
        std = torch.tensor((0.229, 0.224, 0.225))
        mean = mean[:, None, None]
        std = std[:, None, None]

        self.model.eval()
        with torch.no_grad():
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().cpu().numpy()
                report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0].split()
                print(report)

                char2word = [idx for word_idx, word in enumerate(report) for idx in [word_idx] * (len(word) + 1)][:-1]

                attention_weights = self.model.encoder_decoder.attention_weights[:-1]
                assert len(attention_weights) == len(report)
                for word_idx, (attns, word) in enumerate(zip(attention_weights, report)):
                    for layer_idx, attn in enumerate(attns):
                        os.makedirs(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)

                        heatmap = generate_heatmap(image, attn.mean(1).squeeze())



In [None]:
from PIL import Image
import os
import torch
from torchvision import transforms

# Load the image
image_path_1 = "/content/drive/MyDrive/iu_xray/images/CXR1728_IM-0479/0.png"  # Replace with the path to your first image
image_path_2 = "/content/drive/MyDrive/iu_xray/images/CXR1728_IM-0479/1.png"  # Replace with the path to your second image

# Extract image IDs from file paths
image_id_1 = os.path.basename(image_path_1)
image_id_2 = os.path.basename(image_path_2)

image_1 = Image.open(image_path_1).convert('RGB')
image_2 = Image.open(image_path_2).convert('RGB')

# Preprocess the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

image_1 = transform(image_1)
image_2 = transform(image_2)

# Stack the images along a new dimension
image = torch.stack([image_1, image_2], dim=0)

# Add batch dimension
image = image.unsqueeze(0)

# Move the input image to the same device as the model
device = next(model.parameters()).device
image = image.to(device)

# Print the image IDs
print(f"Image ID 1: {image_id_1}")
print(f"Image ID 2: {image_id_2}")

# Generate the report
with torch.no_grad():
    output, _ = model(image, mode='sample')

generated_report = model.tokenizer.decode_batch(output.cpu().detach().numpy())[0]

# Print the generated report
print(f"Generated Report for Image ID 1: {generated_report}")


Image ID 1: 0.png
Image ID 2: 1.png
Generated Report for Image ID 1: granulomatous infrahilar probable right granulomatous probable infrahilar granulomatous overall granulomatous infrahilar granulomatous azygos 6th granulomatous probable 6th bilaterallythere granulomatous 6th granulomatous infrahilar granulomatous infrahilar 6th infrahilar overlie 6th granulomatous probable right right 6th residual infrahilar skin granulomatous inferior one demonstrates tissues demonstrates demonstrates probable granulomatous probable demonstrates tissues cardiopulmonary granulomatous probable lung residual obscured residual tissues granulomatous infrahilar granulomatous infrahilar


here we are generating report from a single image

In [None]:

import argparse
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# from models.models import BaseCMNModel
# from modules.dataloaders import R2DataLoader
# from modules.loss import compute_loss
# from modules.metrics import compute_scores
# from modules.tokenizers import Tokenizer
# from modules.tester import Tester

def main():
    args = {
         'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'threshold': 3,
        'num_workers': 2,
        'batch_size': 16,
        'visual_extractor': 'resnet101',
        'visual_extractor_pretrained': True,
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
        'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 100,
        'save_dir': 'results/iu_xray',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,
        # 'load': "/content/results/iu_xray/model_best.pth",
        'load': "/content/drive/MyDrive/CMN_saved_checkpoints/current_checkpoint.pth",

    }

    # Fix random seeds
    torch.manual_seed(args['seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args['seed'])

    # Create tokenizer
    tokenizer = Tokenizer(args)

    # Create data loader
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # Build model architecture
    model = BaseCMNModel(args, tokenizer)


if __name__ == '__main__':
    main()

from PIL import Image
import os
import shutil
import torch
from torchvision import transforms

# Original image path
original_image_path = "/content/drive/MyDrive/iu_xray/images/CXR1001_IM-0004/1.png"  # Replace with the path to your image

# Specify the directory where you want to create a copy
copy_directory = "/content/drive/MyDrive/iu_xray/images/CXR1728_IM-0479/"
copy_image_path = os.path.join(copy_directory, "0_copy.png")

# Make a copy of the image
shutil.copy(original_image_path, copy_image_path)

# Extract image IDs from file paths
original_image_id = os.path.basename(original_image_path)
copy_image_id = os.path.basename(copy_image_path)

# Open the images and convert to RGB
original_image = Image.open(original_image_path).convert('RGB')
copy_image = Image.open(copy_image_path).convert('RGB')

# Preprocess the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

original_image = transform(original_image)
copy_image = transform(copy_image)

# Stack the images along a new dimension
image = torch.stack([original_image, original_image], dim=0)

# Add batch dimension
image = image.unsqueeze(0)

# Move the input image to the same device as the model
device = next(model.parameters()).device
image = image.to(device)

# Print the image IDs
print(f"Original Image ID: {original_image_id}")
# print(f"Copy Image ID: {copy_image_id}")

# Generate the report
with torch.no_grad():
    output, _ = model(image, mode='sample')

generated_report = model.tokenizer.decode_batch(output.cpu().detach().numpy())[0]

# Print the generated report
print(f"Generated Report for Original Image: {generated_report}")




Original Image ID: 1.png
Generated Report for Original Image: irregularity granulomatous probable infrahilar cardiopulmonary probable infrahilar granulomatous probable infrahilar segment overall granulomatous inferior granulomatous probable infrahilar limits rib significantly rib cardiopulmonary 6th infrahilar infrahilar granulomatous infrahilar granulomatous infrahilar limits several obscured granulomatous caval both infrahilar granulomatous demonstrates elevation granulomatous demonstrates tissues demonstrates overall overlie granulomatous infrahilar pneumothoraces limits granulomatous probable lung typical fractures 6th carina demonstrates tissues 6th infrahilar


here we generate report from two images

In [None]:

import argparse
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import os

# from models.models import BaseCMNModel
# from modules.dataloaders import R2DataLoader
# from modules.loss import compute_loss
# from modules.metrics import compute_scores
# from modules.tokenizers import Tokenizer
# from modules.tester import Tester

def main():
    args = {
         'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'threshold': 3,
        'num_workers': 2,
        'batch_size': 16,
        'visual_extractor': 'resnet101',
        'visual_extractor_pretrained': True,
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
        'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 100,
        'save_dir': 'results/iu_xray',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,
        # 'load': "/content/results/iu_xray/model_best.pth",
        'load': "/content/drive/MyDrive/CMN_saved_checkpoints/current_checkpoint.pth",

    }

    # Fix random seeds
    torch.manual_seed(args['seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args['seed'])

    # Create tokenizer
    tokenizer = Tokenizer(args)

    # Create data loader
    # test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # Build model architecture
    model = BaseCMNModel(args, tokenizer)



    # Load the image
    image_path_1 = "/content/drive/MyDrive/iu_xray/images/CXR1728_IM-0479/0.png"  # Replace with the path to your first image
    image_path_2 = "/content/drive/MyDrive/iu_xray/images/CXR1728_IM-0479/1.png"  # Replace with the path to your second image

    # Extract image IDs from file paths
    image_id_1 = os.path.basename(image_path_1)
    image_id_2 = os.path.basename(image_path_2)

    image_1 = Image.open(image_path_1).convert('RGB')
    image_2 = Image.open(image_path_2).convert('RGB')

    # Preprocess the images
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    image_1 = transform(image_1)
    image_2 = transform(image_2)

    # Stack the images along a new dimension
    image = torch.stack([image_1, image_2], dim=0)

    # Add batch dimension
    image = image.unsqueeze(0)

    # Move the input image to the same device as the model
    device = next(model.parameters()).device
    image = image.to(device)

    # Print the image IDs
    print(f"Image ID 1: {image_id_1}")
    print(f"Image ID 2: {image_id_2}")

    # Generate the report
    with torch.no_grad():
        output, _ = model(image, mode='sample')

    generated_report = model.tokenizer.decode_batch(output.cpu().detach().numpy())[0]

    # Print the generated report
    print(f"Generated Report for Image ID 1: {generated_report}")


if __name__ == '__main__':
    main()





Image ID 1: 0.png
Image ID 2: 1.png


RuntimeError: ignored

here we check the pre-trained model accuracy

In [None]:
#tester.py
import logging
import os
from abc import abstractmethod

import cv2
import numpy as np
import pandas as pd
import spacy
import torch
from tqdm import tqdm




class BaseTester(object):
    def __init__(self, model, criterion, metric_ftns, args):
        self.args = args

        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(args['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns

        self.epochs = self.args['epochs']
        self.save_dir = self.args['save_dir']
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self._load_checkpoint(args['load'])

    @abstractmethod
    def test(self):
        raise NotImplementedError

    @abstractmethod
    def plot(self):
        raise NotImplementedError

    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format(
                    n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _load_checkpoint(self, load_path):
        load_path = str(load_path)
        self.logger.info("Loading checkpoint: {} ...".format(load_path))
        checkpoint = torch.load(load_path)
        self.model.load_state_dict(checkpoint['state_dict'])


class Tester(BaseTester):
    def __init__(self, model, criterion, metric_ftns, args, test_dataloader):
        super(Tester, self).__init__(model, criterion, metric_ftns, args)
        self.test_dataloader = test_dataloader

    def test(self):
        self.logger.info('Start to evaluate in the test set.')
        self.model.eval()
        log = dict()
        with torch.no_grad():
            test_gts, test_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                test_res.extend(reports)
                test_gts.extend(ground_truths)

            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [re] for i, re in enumerate(test_res)})
            log.update(**{'test_' + k: v for k, v in test_met.items()})
            print(log)

            test_res, test_gts = pd.DataFrame(test_res), pd.DataFrame(test_gts)
            test_res.to_csv(os.path.join(self.save_dir, "res.csv"), index=False, header=False)
            test_gts.to_csv(os.path.join(self.save_dir, "gts.csv"), index=False, header=False)

        return log

    def plot(self):
        assert self.args['batch_size'] == 1 and self.args['beam_size'] == 1
        self.logger.info('Start to plot attention weights in the test set.')
        os.makedirs(os.path.join(self.save_dir, "attentions"), exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, "attentions_entities"), exist_ok=True)
        ner = spacy.load("en_core_sci_sm")
        mean = torch.tensor((0.485, 0.456, 0.406))
        std = torch.tensor((0.229, 0.224, 0.225))
        mean = mean[:, None, None]
        std = std[:, None, None]

        self.model.eval()
        with torch.no_grad():
            for batch_idx, (images_id, images, reports_ids, reports_masks) in tqdm(enumerate(self.test_dataloader)):
                images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
                    self.device), reports_masks.to(self.device)
                output, _ = self.model(images, mode='sample')
                image = torch.clamp((images[0].cpu() * std + mean) * 255, 0, 255).int().cpu().numpy()
                report = self.model.tokenizer.decode_batch(output.cpu().numpy())[0].split()

                char2word = [idx for word_idx, word in enumerate(report) for idx in [word_idx] * (len(word) + 1)][:-1]

                attention_weights = self.model.encoder_decoder.attention_weights[:-1]
                assert len(attention_weights) == len(report)
                for word_idx, (attns, word) in enumerate(zip(attention_weights, report)):
                    for layer_idx, attn in enumerate(attns):
                        os.makedirs(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)

                        heatmap = generate_heatmap(image, attn.mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx), "{:04d}_{}.png".format(word_idx, word)),
                                    heatmap)

                for ne_idx, ne in enumerate(ner(" ".join(report)).ents):
                    for layer_idx in range(len(attention_weights[0])):
                        os.makedirs(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx)), exist_ok=True)
                        attn = [attns[layer_idx] for attns in
                                attention_weights[char2word[ne.start_char]:char2word[ne.end_char] + 1]]
                        attn = np.concatenate(attn, axis=2)
                        heatmap = generate_heatmap(image, attn.mean(1).mean(1).squeeze())
                        cv2.imwrite(os.path.join(self.save_dir, "attentions_entities", "{:04d}".format(batch_idx),
                                                 "layer_{}".format(layer_idx), "{:04d}_{}.png".format(ne_idx, ne)),
                                    heatmap)


In [None]:
import argparse
import numpy as np
import torch


def main():
    args = {
        'image_dir': '/content/drive/MyDrive/iu_xray/images',
        'ann_path': '/content/drive/MyDrive/iu_xray/annotation.json',
        'dataset_name': 'iu_xray',
        'max_seq_length': 60,
        'threshold': 3,
        'num_workers': 2,
        'batch_size': 16,
        'visual_extractor': 'resnest200e',
        'visual_extractor_pretrained': True,
        'd_model': 512,
        'd_ff': 512,
        'd_vf': 2048,
        'num_heads': 8,
        'num_layers': 3,
        'dropout': 0.1,
        'logit_layers': 1,
        'bos_idx': 0,
        'eos_idx': 0,
        'pad_idx': 0,
        'use_bn': 0,
        'drop_prob_lm': 0.5,
        'topk': 32,
        'cmm_size': 2048,
        'cmm_dim': 512,
        'sample_method': 'beam_search',
        'beam_size': 3,
        'temperature': 1.0,
        'sample_n': 1,
        'group_size': 1,
        'output_logsoftmax': 1,
        'decoding_constraint': 0,
        'block_trigrams': 1,
        'n_gpu': 1,
        'epochs': 100,
        'save_dir': 'results/iu_xray',
        'record_dir': 'records/',
        'log_period': 1000,
        'save_period': 1,
        'monitor_mode': 'max',
        'monitor_metric': 'BLEU_4',
        'early_stop': 50,
        'optim': 'Adam',
        'lr_ve': 5e-5,
        'lr_ed': 7e-4,
        'weight_decay': 5e-5,
        'adam_betas': (0.9, 0.98),
        'adam_eps': 1e-9,
        'amsgrad': True,
        'noamopt_warmup': 5000,
        'noamopt_factor': 1,
        'lr_scheduler': 'StepLR',
        'step_size': 50,
        'gamma': 0.1,
        'seed': 9233,
        'resume': None,
        # 'load': "/content/results/iu_xray/model_best.pth",
        'load': "/content/drive/MyDrive/resnest200e_save_checkpoints/model_best.pth",

    }

    # fix random seeds
    torch.manual_seed(args['seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args['seed'])

    # create tokenizer
    tokenizer = Tokenizer(args)

    # create data loader
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # build model architecture
    model = BaseCMNModel(args, tokenizer)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    # build trainer and start to train
    tester = Tester(model, criterion, metrics, args, test_dataloader)
    tester.test()


if __name__ == '__main__':
    main()



bellow the visual extractor
Visual output size: 2048
linear_input_size: 2048
bellow the linear layer




in basecmnmodel bellow the encode_decoder
in basecmnmodel iu xray forward pass


RuntimeError: ignored