In [1]:
import torch
from models.models import BaseCMNModel
import argparse

In [2]:
def _str2bool(value):
    """Convert a command-line string such as 'true' / 'false' into a real bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including 'False') would be parsed as True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_agrs(argv=None):
    """Build the argument parser for the experiment and parse the options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before. Inside a notebook kernel
            ``sys.argv`` typically carries the kernel's own arguments, so call
            ``parse_agrs([])`` there to get the defaults.

    Returns:
        argparse.Namespace holding every option below; ``adam_betas`` is always
        a tuple of two floats.
    """
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--image_dir', type=str, default='F:/radiologyReportGeneration/datasets/iu_xray/images/',
                        help='the path to the directory containing the data.')
    parser.add_argument('--ann_path', type=str, default='F:/radiologyReportGeneration/datasets/iu_xray/annotation.json',
                        help='the path to the annotation file.')

    # IDAM
    # NOTE: store_false means the option defaults to True and passing the flag turns it OFF.
    parser.add_argument("--useIDAM", action="store_false",
                        help="whether to use IDAM (enabled by default; passing this flag disables it).")

    # Mamba
    # NOTE: same store_false semantics as --useIDAM.
    parser.add_argument("--useVTFCM", action="store_false",
                        help="whether to use VTFCM (enabled by default; passing this flag disables it).")

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr'],
                        help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2, help='the number of workers for dataloader.')
    parser.add_argument('--batch_size', type=int, default=2, help='the number of samples for a batch')

    # Model settings (for visual extractor)
    parser.add_argument('--visual_extractor', type=str, default='resnet101', help='the visual extractor to be used.')
    # _str2bool instead of bool so that "--visual_extractor_pretrained False" really yields False.
    parser.add_argument('--visual_extractor_pretrained', type=_str2bool, default=True, help='whether to load the pretrained visual extractor')

    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')

    # for Cross-modal Memory
    parser.add_argument('--topk', type=int, default=32, help='the number of k.')
    parser.add_argument('--cmm_size', type=int, default=2048, help='the numebr of cmm size.')
    parser.add_argument('--cmm_dim', type=int, default=512, help='the dimension of cmm dimension.')

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sample methods to sample a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to use block trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray/', help='the patch to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the patch to save the results of experiments.')
    parser.add_argument('--log_period', type=int, default=50, help='the logging interval (in batches).')
    parser.add_argument('--save_period', type=int, default=10, help='the saving period (in epochs).')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to max or min the metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=1e-4, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=5e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    # Two floats on the command line ("--adam_betas 0.9 0.98"); type=tuple would have
    # split a single string into its characters.
    parser.add_argument('--adam_betas', type=float, nargs=2, default=(0.9, 0.98), metavar=('BETA1', 'BETA2'),
                        help='the two beta coefficients of Adam.')
    parser.add_argument('--adam_eps', type=float, default=1e-9, help='the epsilon of Adam.')
    parser.add_argument('--amsgrad', type=_str2bool, default=True, help='.')
    parser.add_argument('--noamopt_warmup', type=int, default=5000, help='.')
    parser.add_argument('--noamopt_factor', type=int, default=1, help='.')

    # Learning Rate Scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=10, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.8, help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='.')
    parser.add_argument('--resume', type=str, help='whether to resume the training from existing checkpoints.')

    args = parser.parse_args(argv)
    # nargs=2 produces a list when the option is given; keep the documented tuple type.
    args.adam_betas = tuple(args.adam_betas)
    return args

In [4]:
# torch.load unpickles the file, so only open checkpoints that come from a trusted source.
# map_location="cpu" lets a checkpoint that was saved from a GPU be inspected on a machine
# without CUDA; the tensors are only examined here, not used for training.
weight = torch.load("F:/radiologyReportGeneration/result/CMN+IFAM+CS/model_best_1.pth", map_location="cpu")
# Last expression of the cell: shows the top-level entries of the checkpoint.
weight.keys()

dict_keys(['epoch', 'state_dict', 'optimizer', 'monitor_best'])

In [7]:
# Index directly so that a checkpoint without a 'state_dict' entry fails right here with a
# KeyError, instead of yielding None and failing in a later cell with a less obvious error.
model_weight = weight['state_dict']

In [8]:
# Show the names of all entries stored in the state dict (last expression, so the notebook displays it).
model_weight.keys()

odict_keys(['visual_extractor.model.0.weight', 'visual_extractor.model.1.weight', 'visual_extractor.model.1.bias', 'visual_extractor.model.1.running_mean', 'visual_extractor.model.1.running_var', 'visual_extractor.model.1.num_batches_tracked', 'visual_extractor.model.4.0.conv1.weight', 'visual_extractor.model.4.0.bn1.weight', 'visual_extractor.model.4.0.bn1.bias', 'visual_extractor.model.4.0.bn1.running_mean', 'visual_extractor.model.4.0.bn1.running_var', 'visual_extractor.model.4.0.bn1.num_batches_tracked', 'visual_extractor.model.4.0.conv2.weight', 'visual_extractor.model.4.0.bn2.weight', 'visual_extractor.model.4.0.bn2.bias', 'visual_extractor.model.4.0.bn2.running_mean', 'visual_extractor.model.4.0.bn2.running_var', 'visual_extractor.model.4.0.bn2.num_batches_tracked', 'visual_extractor.model.4.0.conv3.weight', 'visual_extractor.model.4.0.bn3.weight', 'visual_extractor.model.4.0.bn3.bias', 'visual_extractor.model.4.0.bn3.running_mean', 'visual_extractor.model.4.0.bn3.running_var', 

In [12]:
# Look up an entry named exactly "ImageFA". .get returns None when there is no such key,
# in which case the notebook displays nothing for this cell.
model_weight.get("ImageFA")

In [42]:
import torch.nn as nn

# Both tensors share the same square shape.
mask_shape = (49, 49)

# Frozen matrix of values drawn from a normal distribution (mean 1, std 0.05);
# requires_grad=False keeps it out of gradient computation.
mask = nn.Parameter(
    torch.normal(mean=1, std=0.05, size=mask_shape),
    requires_grad=False,
)

# Standard-normal noise of the same shape.
mask_v = torch.randn(*mask_shape)

In [43]:
# Softmax over each row of the perturbed mask. dim=1 is the dimension nn.Softmax() falls back
# to for a 2-D input, so the values are unchanged; naming it explicitly removes the
# "implicit dimension choice" deprecation warning.
# Note: this rebinds `mask`, so re-running only this cell applies softmax to the previous result.
mask = nn.Softmax(dim=1)(mask + mask_v)

  return self._call_impl(*args, **kwargs)


In [44]:
# Zero, in place, the block formed by the first 32 rows and the last 32 columns.
top_rows = slice(None, 32)
right_cols = slice(-32, None)
mask[top_rows, right_cols] = 0

In [46]:
# Flatten the mask to 1-D, round to two decimals, scale by 100 and print every entry.
rounded = mask.flatten().numpy().round(2)
print(list(rounded * 100))

[8.0, 4.0, 0.0, 4.0, 2.0, 1.0, 2.0, 1.0, 3.0, 1.0, 3.0, 3.0, 4.0, 2.0, 1.0, 2.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 1.0, 4.0, 2.0, 6.0, 0.0, 2.0, 5.0, 1.0, 7.0, 0.0, 7.0, 2.0, 1.0, 2.0, 5.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 7.0, 1.0, 4.0, 1.0, 9.0, 3.0, 0.0, 1.0, 1.0, 2.0, 4.0, 2.0, 9.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 2.0, 2.0, 7.0, 1.0, 1.0, 3.0, 13.0, 2.0, 2.0, 2.0, 1.0, 1.0, 6.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0