In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# os.makedirs('/kaggle/working/checkpoints/')

In [None]:

# --- Base packages ---
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics as metrics

# --- PyTorch packages ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# --- Helper packages ---
from random import shuffle
from tqdm import tqdm
import sentencepiece as spm
from PIL import Image, ImageFile
# Let PIL load image files that are truncated on disk instead of failing on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# --- Datasets ---
class NLMCXR(data.Dataset): # Open-I Dataset
    """Open-I (Indiana University) chest X-ray dataset pairing image views with report text.

    What ``__getitem__`` returns is chosen by ``sources`` and ``targets``:

    * ``'image'``          -> ``(imgs, vpos)``: a ``(V,C,H,W)`` tensor of up to
                              ``max_views`` views and a ``(V,)`` int64 array
                              (1 = real view, -1 = zero padding)
    * ``'history'``        -> encoded INDICATION + COMPARISON report sections
    * ``'label'``          -> file labels concatenated with top-noun-phrase flags
    * ``'caption'``        -> encoded preprocessed caption from ``captions.json``
    * ``'caption_length'`` -> number of non-padding caption tokens

    ``targets`` only honours ``'label'``, ``'caption'`` and ``'caption_length'``.
    A single requested item is returned bare instead of inside a list.

    Args:
        directory: root path. It is concatenated directly with relative file
            names, so it must end with a path separator.
        input_size: size passed to ``transforms.Resize``.
        random_transform: apply random flip / colour-jitter / rotation augmentation.
        view_pos: stored but not read by this class (default ['AP', 'PA', 'LATERAL']).
        max_views: number of image slots V per item; missing views are zero-padded.
        sources: inputs to return (default ['image', 'history']).
        targets: outputs to return (default ['label']).
        max_len: fixed length of every encoded token sequence (longer ones are truncated).
        vocab_file: SentencePiece model path relative to ``directory``.
    """
    def __init__(self, directory, input_size=(256,256), random_transform=True,
                view_pos=None, max_views=2, sources=None, targets=None,
                max_len=1000, vocab_file='report-generation-support/nlmcxr_unigram_1000.model'):
        self.source_sections = ['INDICATION', 'COMPARISON']
        self.target_sections = ['FINDINGS'] # not read by this class; captions come from captions.json
        self.vocab = spm.SentencePieceProcessor(model_file=directory + vocab_file)
        self.vocab_file = vocab_file # Save it for subsets

        # None defaults avoid one mutable list being shared by every instance.
        self.sources = ['image', 'history'] if sources is None else sources # Choose which section as input
        self.targets = ['label'] if targets is None else targets # Choose which section as output
        self.max_views = max_views
        self.view_pos = ['AP', 'PA', 'LATERAL'] if view_pos is None else view_pos
        self.max_len = max_len

        self.dir = directory
        self.input_size = input_size
        self.random_transform = random_transform
        self.__input_data(binary_mode=True)

        if random_transform:
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.1,0.1,0.1),
                    transforms.RandomRotation(15, expand=True)]),
                transforms.Resize(input_size),
                transforms.ToTensor(),
            ])
        else:
            self.transform = transforms.Compose([transforms.Resize(input_size), transforms.ToTensor()])

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        entry = self.file_report[file_name]
        # Preprocessed caption text (from the authors' previous AAAI paper), keyed by the study's first image id.
        caption_text = self.captions[entry['image'][0] + '.png']

        # ------ Multiview Images ------
        if 'image' in self.sources:
            imgs, vpos = self.__load_views(entry['image'])

        # ------ Additional Information ------
        history = ' '.join(content for section, content in entry['report'].items()
                           if section in self.source_sections)
        source_info, _ = self.__encode(history)

        # Flag each top noun phrase that occurs (as a substring) in the caption.
        np_labels = np.array([1.0 if phrase in caption_text else 0.0 for phrase in self.top_np], dtype=float)
        target_info, caption_length = self.__encode(caption_text)

        # Lazily evaluated so e.g. file labels are only looked up when actually requested.
        providers = {
            'image': lambda: (imgs, vpos),
            'history': lambda: source_info,
            'label': lambda: np.concatenate([np.array(self.file_labels[file_name]), np_labels]),
            'caption': lambda: target_info,
            'caption_length': lambda: caption_length,
        }
        sources = [providers[name]() for name in self.sources if name in providers]
        targets = [providers[name]() for name in self.targets if name in ('label', 'caption', 'caption_length')]
        return sources if len(sources) > 1 else sources[0], targets if len(targets) > 1 else targets[0]

    def __encode(self, text):
        """Encode ``text`` as BOS + pieces + EOS, truncated/padded to ``max_len``.

        Returns ``(ids, length)`` where ``ids`` is an int64 array of length
        ``max_len`` filled with the vocab's pad id after the first ``length`` tokens.
        """
        encoded = [self.vocab.bos_id()] + self.vocab.encode(text) + [self.vocab.eos_id()]
        length = min(len(encoded), self.max_len)
        ids = np.ones(self.max_len, dtype=np.int64) * self.vocab.pad_id()
        ids[:length] = encoded[:length]
        return ids, length

    def __load_views(self, images):
        """Load up to ``max_views`` randomly ordered views of one study, zero-padded to ``max_views``.

        Returns ``(imgs, vpos)``: ``imgs`` stacks the transformed views ``(V,C,H,W)``;
        ``vpos`` is 1 for a real view (its view position is unknown) and -1 for padding.
        """
        # Random order, so a different subset of views can be picked each time.
        img_files = np.array(images)[np.random.permutation(len(images))].tolist()
        image_dir = os.path.join(self.dir, 'chest-xrays-indiana-university', 'images', 'images_normalized')
        imgs, vpos = [], []
        for img_id in img_files[:self.max_views]:
            # These two ids map to files named without the 'CXR1_' prefix; all others just drop 'CXR'.
            strip = 5 if img_id in ('CXR1_1_IM-0001-4001', 'CXR1_1_IM-0001-3001') else 3
            # Close the file handle; convert() loads the pixels into an independent image.
            with Image.open(os.path.join(image_dir, img_id[strip:] + '.dcm.png')) as raw:
                img = raw.convert('RGB')
            imgs.append(self.transform(img).unsqueeze(0)) # (1,C,H,W)
            vpos.append(1)

        # Pad with blank images (masked by vpos = -1) so every item has max_views slots.
        while len(vpos) < self.max_views:
            imgs.append(torch.zeros_like(imgs[0]))
            vpos.append(-1) # Empty mask
        return torch.cat(imgs, dim=0), np.array(vpos, dtype=np.int64) # (V,C,H,W), (V)

    def __get_nounphrase(self, top_k=100, file_name='report-generation-support/count_nounphrase.json'):
        """Return the ``top_k`` noun phrases with the highest counts in ``file_name``."""
        with open(self.dir + file_name, 'r') as f:
            count_np = json.load(f)
        sorted_count_np = sorted(count_np.items(), key=lambda x: x[1], reverse=True)
        return [k for k, v in sorted_count_np][:top_k]

    def __input_data(self, binary_mode=True):
        # binary_mode is accepted but not used.
        self.__input_caption()
        self.__input_report()
        self.__input_label()
        self.__filter_inputs()
        self.top_np = self.__get_nounphrase()

    def __input_label(self):
        with open(self.dir + 'report-generation-support/file2label.json') as f:
            labels = json.load(f)
        self.file_labels = labels

    def __input_caption(self):
        with open(self.dir + 'report-generation-support/captions.json') as f:
            captions = json.load(f)
        self.captions = captions

    def __input_report(self):
        with open(self.dir + 'report-generation-support/reports_ori.json') as f:
            reports = json.load(f)
        self.file_list = [k for k in reports.keys()]
        self.file_report = reports

    def __filter_inputs(self):
        """Keep only reports with at least one image and a non-empty FINDINGS section."""
        self.file_report = {
            k: v for k, v in self.file_report.items()
            if len(v['image']) > 0 and v['report'].get('FINDINGS', '') != ''
        }
        self.file_list = [k for k in self.file_report.keys()]

    def get_subsets(self, train_size=0.7, val_size=0.1, test_size=0.2, seed=0):
        """Randomly split ``file_list`` into train / validation / test datasets.

        The first ``int(train_size * n)`` shuffled studies go to train, studies up to
        ``int((train_size + val_size) * n)`` go to validation, and the remainder go to
        test (``test_size`` itself is not read). Only the training subset keeps
        ``random_transform``. Note: this reseeds NumPy's global RNG with ``seed``.
        """
        np.random.seed(seed)
        n = len(self.file_list)
        indices = np.random.permutation(n)
        train_pvt = int(train_size * n)
        val_pvt = int((train_size + val_size) * n)
        master_file_list = np.array(self.file_list)

        def make_subset(subset_indices, random_transform):
            # Each subset re-reads the support files, then keeps only its share of studies.
            subset = NLMCXR(self.dir, self.input_size, random_transform,
                            self.view_pos, self.max_views, self.sources, self.targets, self.max_len, self.vocab_file)
            subset.file_list = master_file_list[subset_indices].tolist()
            return subset

        train_dataset = make_subset(indices[:train_pvt], self.random_transform)
        val_dataset = make_subset(indices[train_pvt:val_pvt], False)
        test_dataset = make_subset(indices[val_pvt:], False)
        return train_dataset, val_dataset, test_dataset

In [None]:


# ------ Helper Functions ------
def data_to_device(data, device='cpu'):
	"""Recursively move every tensor inside ``data`` to ``device``.

	Tuples, lists and dicts are rebuilt (as plain tuple / list / dict) with their
	contents moved; any other leaf type raises ``TypeError``.
	"""
	if isinstance(data, torch.Tensor):
		return data.to(device)
	if isinstance(data, tuple):
		return tuple([data_to_device(element, device) for element in data])
	if isinstance(data, list):
		return [data_to_device(element, device) for element in data]
	if isinstance(data, dict):
		return {key: data_to_device(value, device) for key, value in data.items()}
	raise TypeError('Unsupported Datatype! Must be a Tensor/List/Tuple/Dict.')

def data_concatenate(iterable_data, dim=0):
	"""Concatenate a sequence of identically structured batches along ``dim``.

	``iterable_data`` is an indexable sequence whose items are either all tensors
	or all tuples / lists / dicts with the same layout. Containers are merged
	position by position (key by key for dicts), recursing into nested
	containers, and the result is a plain tuple / list / dict matching the kind
	of the first item.

	Raises:
		IndexError: if ``iterable_data`` is empty.
		TypeError: if a leaf is not a tensor.
	"""
	first = iterable_data[0] # can be a list / tuple / dict / tensor
	if isinstance(first, torch.Tensor):
		return torch.cat(list(iterable_data), dim=dim)
	if isinstance(first, (tuple, list)):
		# Column-wise merge: the i-th entries of every row are concatenated together.
		merged = [data_concatenate([row[col] for row in iterable_data], dim=dim) for col in range(len(first))]
		return tuple(merged) if isinstance(first, tuple) else merged
	if isinstance(first, dict):
		return {key: data_concatenate([row[key] for row in iterable_data], dim=dim) for key in first.keys()}
	raise TypeError('Unsupported Datatype! Must be a Tensor/List/Tuple/Dict.')

def data_distributor(model, source):
	"""Call ``model`` with ``source`` unpacked according to its container type.

	A dict is spread as keyword arguments, a tuple / list as positional
	arguments, and a lone tensor is passed as the single argument.
	"""
	if isinstance(source, dict):
		return model(**source)
	if isinstance(source, (tuple, list)):
		return model(*source)
	if isinstance(source, torch.Tensor):
		return model(source)
	raise TypeError('Unsupported DataType! Try List/Tuple!')
	
def args_to_kwargs(args, kwargs_list=None): # This function helps distribute input to corresponding arguments in Torch models
	"""Name positional batch entries so they can be passed to a model as keywords.

	Args:
		args: a dict (returned untouched), a single tensor, or a list / tuple of entries.
		kwargs_list: names assigned to the entries in order; ``None`` disables the
			conversion and ``args`` is returned unchanged.

	Returns:
		``args`` itself, or ``dict(zip(kwargs_list, args))``.

	Raises:
		ValueError: if the number of entries differs from the number of names.
	"""
	if kwargs_list is None or isinstance(args, dict): # Nothing to do here
		return args
	if isinstance(args, torch.Tensor): # a lone tensor counts as one entry
		args = [args]
	# Explicit check instead of ``assert``: asserts vanish under ``python -O`` and
	# ``zip`` would then silently drop unmatched entries.
	if len(args) != len(kwargs_list):
		raise ValueError('Expected {} entries for {}, got {}'.format(len(kwargs_list), kwargs_list, len(args)))
	return dict(zip(kwargs_list, args))

In [None]:
class TextDataset(data.Dataset):
    """Line-per-sample text dataset with one label row per line.

    ``text_file`` holds one report per line. ``label_file`` is read with
    ``np.loadtxt``, and row ``i`` belongs to line ``i``. Each line is stripped,
    encoded with SentencePiece as BOS + pieces + EOS, and truncated or padded
    (with the vocab's pad id) to ``max_len``. ``sources`` / ``targets`` choose what
    ``__getitem__`` returns: ``'label'``, ``'caption'`` (encoded ids) and/or
    ``'caption_length'`` (non-padding token count). A single selected item is
    returned bare instead of inside a list.
    """
    def __init__(self, text_file, label_file, sources=None, targets=None,
                 vocab_file='/kaggle/input/report-generation-support/nlmcxr_unigram_1000.model', max_len=1000):
        self.text_file = text_file
        self.label_file = label_file
        self.vocab = spm.SentencePieceProcessor(model_file=vocab_file)
        # None defaults avoid one mutable list being shared by every instance.
        self.sources = ['caption'] if sources is None else sources # Choose which section as input
        self.targets = ['label'] if targets is None else targets # Choose which section as output
        self.max_len = max_len
        self.__input_data()

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        encoded_text = [self.vocab.bos_id()] + self.vocab.encode(self.lines[idx].strip()) + [self.vocab.eos_id()]
        length = min(len(encoded_text), self.max_len)
        text = np.ones(self.max_len, dtype=np.int64) * self.vocab.pad_id()
        text[:length] = encoded_text[:length]

        # Lazily evaluated so each requested item is produced on demand.
        providers = {
            'label': lambda: self.labels[idx],
            'caption': lambda: text,
            'caption_length': lambda: length,
        }
        sources = [providers[name]() for name in self.sources if name in providers]
        targets = [providers[name]() for name in self.targets if name in providers]
        return sources if len(sources) > 1 else sources[0], targets if len(targets) > 1 else targets[0]

    def __input_data(self):
        # The context manager closes the text file once all lines are read.
        with open(self.text_file, 'r') as data_file:
            self.lines = data_file.readlines()
        self.labels = np.loadtxt(self.label_file, dtype='float')

In [None]:
# ------ Core Functions ------
def train(data_loader, model, optimizer, criterion, scheduler=None, device='cpu', kw_src=None, kw_tgt=None, kw_out=None, scaler=None):
	"""Run one training epoch and return the mean per-batch loss.

	Each ``(source, target)`` batch is moved to ``device``, optionally converted to
	keyword dicts through ``kw_src`` / ``kw_tgt`` / ``kw_out`` (``args_to_kwargs``),
	fed to ``model`` and scored with ``criterion``. ``scheduler.step()`` (if given)
	runs after every batch, not once per epoch.

	When ``scaler`` (e.g. a ``torch.cuda.amp.GradScaler``) is given, the forward pass
	and loss run under CUDA autocast and the update goes through the scaler.

	Returns:
		The summed batch losses divided by ``len(data_loader)`` (0.0 for an empty loader).
	"""
	from contextlib import nullcontext

	model.train()
	running_loss = 0
	use_amp = scaler is not None

	prog_bar = tqdm(data_loader)
	for i, (source, target) in enumerate(prog_bar):
		source = args_to_kwargs(data_to_device(source, device), kw_src)
		target = args_to_kwargs(data_to_device(target, device), kw_tgt)

		# Autocast wraps only the forward pass and loss; backward runs outside it.
		with torch.cuda.amp.autocast() if use_amp else nullcontext():
			output = data_distributor(model, source)
			output = args_to_kwargs(output, kw_out)
			loss = criterion(output, target)

		running_loss += loss.item()
		prog_bar.set_description('Loss: {}'.format(running_loss/(i+1)))

		# Back-propagate and update weights
		optimizer.zero_grad()
		if use_amp:
			scaler.scale(loss).backward()
			scaler.step(optimizer)
			scaler.update()
		else:
			loss.backward()
			optimizer.step()
		if scheduler is not None:
			scheduler.step()

	return running_loss / max(len(data_loader), 1)

def test(data_loader, model, criterion=None, device='cpu', return_results=True, kw_src=None, kw_tgt=None, kw_out=None, select_outputs=None):
	"""Evaluate ``model`` over ``data_loader`` with gradients disabled.

	Batches are moved to ``device`` and optionally converted to keyword dicts with
	``kw_src`` / ``kw_tgt`` / ``kw_out`` (``args_to_kwargs``) before the model call.

	Args:
		criterion: optional loss; when ``None`` the reported loss stays 0.
		return_results: also collect every output and target on the CPU and
			concatenate them across batches with ``data_concatenate``.
		select_outputs: optional indices / keys; when non-empty, only ``output[k]``
			and ``target[k]`` for those ``k`` are collected (a single selection is
			stored bare, not in a list). ``None`` or empty keeps everything.

	Returns:
		``mean_loss``, or ``(mean_loss, outputs, targets)`` when ``return_results``.
		``mean_loss`` is the summed batch loss divided by ``len(data_loader)``.
	"""
	model.eval()
	running_loss = 0
	has_selection = select_outputs is not None and len(select_outputs) > 0

	outputs = []
	targets = []

	with torch.no_grad():
		prog_bar = tqdm(data_loader)
		for i, (source, target) in enumerate(prog_bar):
			source = args_to_kwargs(data_to_device(source, device), kw_src)
			target = args_to_kwargs(data_to_device(target, device), kw_tgt)

			output = data_distributor(model, source)
			output = args_to_kwargs(output, kw_out)

			if criterion is not None:
				loss = criterion(output, target)
				running_loss += loss.item()
			prog_bar.set_description('Loss: {}'.format(running_loss/(i+1)))

			if return_results:
				if has_selection:
					picked_output = [output[k] for k in select_outputs]
					picked_target = [target[k] for k in select_outputs]
					output = picked_output if len(picked_output) > 1 else picked_output[0]
					target = picked_target if len(picked_target) > 1 else picked_target[0]
				outputs.append(data_to_device(output, 'cpu'))
				targets.append(data_to_device(target, 'cpu'))

	# max(..., 1) avoids ZeroDivisionError for an empty loader.
	mean_loss = running_loss / max(len(data_loader), 1)
	if return_results:
		outputs = data_concatenate(outputs)
		targets = data_concatenate(targets)
		return mean_loss, outputs, targets
	return mean_loss

def save(path, model, optimizer=None, scheduler=None, epoch=-1, stats=None):
	"""Write a training checkpoint to ``path`` with ``torch.save``.

	The checkpoint is a dict with keys ``epoch``, ``stats``, ``model_state_dict``,
	``optimizer_state_dict`` and ``scheduler_state_dict``. The last two are
	``None`` when no optimizer / scheduler is supplied.
	"""
	def state_of(component):
		return None if component is None else component.state_dict()

	checkpoint = {
		# --- Model Statistics ---
		'epoch': epoch,
		'stats': stats,
		# --- Model Parameters ---
		'model_state_dict': model.state_dict(),
		'optimizer_state_dict': state_of(optimizer),
		'scheduler_state_dict': state_of(scheduler),
	}
	torch.save(checkpoint, path)

def load(path, model, optimizer=None, scheduler=None, map_location=None):
	"""Restore a checkpoint written by ``save`` and return ``(epoch, stats)``.

	The model weights must match (``load_state_dict`` raises otherwise). Optimizer
	and scheduler states are best effort: if one cannot be restored, a message
	with the reason is printed and loading continues.

	Args:
		map_location: forwarded to ``torch.load``, e.g. ``'cpu'`` to load a GPU
			checkpoint on a machine without CUDA. ``None`` keeps torch's default.
	"""
	checkpoint = torch.load(path, map_location=map_location)
	# --- Model Statistics ---
	epoch = checkpoint['epoch']
	stats = checkpoint['stats']
	# --- Model Parameters ---
	model.load_state_dict(checkpoint['model_state_dict'])
	for name, component in (('optimizer', optimizer), ('scheduler', scheduler)):
		if component is None:
			continue
		try:
			component.load_state_dict(checkpoint[name + '_state_dict'])
		except Exception as err: # a mismatched / missing state is skipped on purpose; KeyboardInterrupt still propagates
			print('Cannot load the {}: {}'.format(name, err))
	return epoch, stats

In [None]:


class KLLoss(nn.Module):
	"""KL-divergence loss for models whose outputs are already probabilities.

	``output`` is mapped to log-space, because ``nn.KLDivLoss`` expects
	log-probabilities as input. ``target`` is used as plain probabilities. The
	default ``nn.KLDivLoss`` reduction is used.
	"""
	def __init__(self):
		super().__init__()
		self.KLLoss = nn.KLDivLoss()

	def forward(self, output, target):
		'''
		Output: (N,*) probabilities \n
		Target: (N,*) probabilities \n
		'''
		log_output = output.log() # invert the softmax
		# Measures how the output distribution diverges from the target distribution.
		return self.KLLoss(log_output, target)


class CELoss(nn.Module):
	"""Cross-entropy for models that emit probabilities over the last axis.

	Probabilities are mapped back to log-space before ``nn.CrossEntropyLoss``.
	Positions whose target equals ``ignore_index`` do not contribute.
	"""
	def __init__(self, ignore_index=-1):
		super().__init__()
		self.CELoss = nn.CrossEntropyLoss(ignore_index=ignore_index)

	def forward(self, output, target):
		'''
		Output: (N,*,C) probabilities \n
		Target: (N,*) class indices \n
		'''
		num_classes = output.shape[-1]
		flat_logits = torch.log(output).reshape(-1, num_classes) # (*,C)
		flat_target = target.reshape(-1).long() # (*)
		return self.CELoss(flat_logits, flat_target)


class CELossSame(nn.Module):
	"""Sum of cross-entropy losses of three probability outputs against one shared target.

	``outputs[0]``, ``outputs[1]`` and ``outputs[2]`` (image / text / sentence
	predictions, going by the original variable names) are each converted to
	log-space and scored against the same ``target``.
	"""
	def __init__(self, ignore_index=-1):
		super().__init__()
		self.CELoss = nn.CrossEntropyLoss(ignore_index=ignore_index)

	def forward(self, outputs, target):
		'''
		Output: 3 x (N,*,C) \n
		Target: (N,*) \n
		'''
		flat_target = target.reshape(-1).long() # (*)
		total = 0
		for head in (outputs[0], outputs[1], outputs[2]):
			log_probs = torch.log(head) # invert the softmax
			total = total + self.CELoss(log_probs.reshape(-1, log_probs.shape[-1]), flat_target)
		return total

class CELossShift(nn.Module):
	"""Next-token cross-entropy: the prediction at step t is scored against target token t+1."""
	def __init__(self, ignore_index=-1):
		super().__init__()
		self.CELoss = CELoss(ignore_index=ignore_index)

	def forward(self, output, target):
		'''
		Output: (N,L,C) \n
		Target: (N,L) \n
		'''
		# Drop the final prediction and the first target token so the two sequences align.
		shifted_output = output[:, :-1, :] # (N,L-1,C)
		shifted_target = target[:, 1:] # (N,L-1)
		return self.CELoss(shifted_output, shifted_target)

class CELossTotal(nn.Module):
	"""Caption loss plus label loss for ``(caption_probs, label_probs)`` outputs.

	``output[0]`` / ``target[0]`` go through ``CELossShift`` (which honours
	``ignore_index``). ``output[1]`` / ``target[1]`` go through a plain ``CELoss``
	built with its default ``ignore_index``.
	"""
	def __init__(self, ignore_index=-1):
		super().__init__()
		self.CELoss = CELoss()
		self.CELossShift = CELossShift(ignore_index=ignore_index)

	def forward(self, output, target):
		caption_loss = self.CELossShift(output[0], target[0])
		label_loss = self.CELoss(output[1], target[1])
		return caption_loss + label_loss

class CELossTotalEval(nn.Module):
	"""Caption loss plus the label loss of two label heads that share one label target.

	``output[0]`` / ``target[0]`` go through ``CELossShift``. ``output[1]`` and
	``output[2]`` are each scored against ``target[1]`` with a plain ``CELoss``.
	"""
	def __init__(self, ignore_index=-1):
		super().__init__()
		self.CELoss = CELoss()
		self.CELossShift = CELossShift(ignore_index=ignore_index)

	def forward(self, output, target):
		total = self.CELossShift(output[0], target[0])
		total = total + self.CELoss(output[1], target[1])
		total = total + self.CELoss(output[2], target[1])
		return total

class CELossTransfer(nn.Module):
	"""Caption-only loss: ``CELossShift`` on ``output[0]`` / ``target[0]``.

	``self.CELoss`` is constructed but not applied, so any further entries of
	``output`` / ``target`` are ignored.
	"""
	def __init__(self, ignore_index=-1):
		super().__init__()
		self.CELoss = CELoss()
		self.CELossShift = CELossShift(ignore_index=ignore_index)

	def forward(self, output, target):
		caption_output, caption_target = output[0], target[0]
		return self.CELossShift(caption_output, caption_target)

In [None]:

from torch.nn.utils.rnn import pack_padded_sequence

class Transformer(nn.Module):
    """Image-to-report transformer: CNN feature map -> transformer encoder -> caption decoder.

    ``image_encoder`` must return ``(avg_features, wxh_features)`` with
    ``wxh_features`` shaped ``(B, fc_features, W, H)``. Its W*H spatial positions
    are projected to ``embed_dim``, given learned pixel embeddings, and encoded
    into the decoder memory.

    Args:
        num_tokens: vocabulary size S.
        num_posits: number of learned caption positions (maximum caption length).
        num_pixels: number of spatial positions W*H produced by ``image_encoder``.
            The default 64 corresponds to an 8x8 final feature map.
        freeze_encoder: disable gradients for ``image_encoder``.
    """
    def __init__(self, image_encoder, num_tokens, num_posits, fc_features=1024, embed_dim=256, num_heads=8, fwd_dim=4096, dropout=0.1, num_layers_enc=1, num_layers_dec=6, freeze_encoder=True, num_pixels=64):

        super().__init__()
        self.token_embedding = nn.Embedding(num_tokens, embed_dim)
        self.posit_embedding = nn.Embedding(num_posits, embed_dim)
        self.pixel_embedding = nn.Embedding(num_pixels, embed_dim) # one learned embedding per feature-map position

        self.transformer_enc = nn.TransformerEncoder(encoder_layer=nn.TransformerEncoderLayer(embed_dim,num_heads,fwd_dim,dropout), num_layers=num_layers_enc)
        self.transformer_dec = nn.TransformerDecoder(decoder_layer=nn.TransformerDecoderLayer(embed_dim,num_heads,fwd_dim,dropout), num_layers=num_layers_dec)

        self.fc1 = nn.Linear(fc_features, embed_dim)
        self.fc2 = nn.Linear(embed_dim, num_tokens)

        self.image_encoder = image_encoder # make sure that image_encoder is a MVCNN model
        if freeze_encoder: # The original paper froze the ImageNet-pretrained DenseNet and still reported very good results
            for param in self.image_encoder.parameters():
                param.requires_grad = False

        self.dropout = nn.Dropout(dropout)
        self.num_tokens = num_tokens
        self.num_posits = num_posits

    def generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float causal mask: 0 on/below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _encode_image(self, image):
        """Turn ``image`` into transformer memory of shape (B, W*H, E)."""
        _, wxh_features = self.image_encoder(image) # (B,F,W,H)
        wxh_features = wxh_features.view(wxh_features.shape[0], wxh_features.shape[1], -1).permute(0,2,1) # (B,W*H,F)
        wxh_features = self.fc1(wxh_features) # (B,W*H,E)

        pixel = torch.arange(wxh_features.shape[1], device=wxh_features.device).unsqueeze(0).repeat(wxh_features.shape[0], 1) # (B,W*H)
        img_features = wxh_features + self.pixel_embedding(pixel) # (B,W*H,E)
        return self.transformer_enc(img_features.transpose(0,1)).transpose(0,1) # (B,W*H,E)

    def _decode(self, img_features, caption, pad_id):
        """Return next-token probabilities (B, L, S) for ``caption`` (B, L) given the image memory."""
        posit = torch.arange(caption.shape[1], device=caption.device).unsqueeze(0).repeat(caption.shape[0], 1) # (B,L)
        cap_features = self.token_embedding(caption) + self.posit_embedding(posit) # (B,L,E)

        tgt_mask = self.generate_square_subsequent_mask(caption.shape[1]).to(caption.device)
        output = self.transformer_dec(tgt=cap_features.transpose(0,1),
                                      memory=img_features.transpose(0,1),
                                      tgt_mask=tgt_mask,
                                      tgt_key_padding_mask=(caption == pad_id)).transpose(0, 1) # (L,B,E) -> (B,L,E)

        preds = self.fc2(self.dropout(output)) # (B,L,S)
        return torch.softmax(preds, dim=-1) # (B,L,S)

    def forward(self, image, caption = None, bos_id=1, eos_id=2, pad_id=3, max_len=300):
        """Teacher-forced probabilities when ``caption`` is given, otherwise greedy decoding.

        Returns:
            With ``caption`` (B, L): softmax probabilities (B, L, S).
            Without: a (B, max_len + 1) LongTensor of token ids that starts with
            ``bos_id``. Decoding always runs ``max_len`` steps (``eos_id`` does not
            stop it), so ``max_len`` must not exceed ``num_posits``.
        """
        img_features = self._encode_image(image) # (B,W*H,E)
        if caption is not None:
            return self._decode(img_features, caption, pad_id) # (B,L,S)

        caption = torch.full((img_features.shape[0], 1), bos_id, dtype=torch.long, device=img_features.device) # (B,1)
        for _ in range(max_len):
            preds = self._decode(img_features, caption, pad_id) # (B,L',S)
            next_token = torch.argmax(preds[:, -1, :], dim=-1, keepdim=True) # (B,1)
            caption = torch.cat([caption, next_token], dim=-1) # (B,L'+1)
        return caption # (B,L')
        
class GumbelTransformer(nn.Module):
    """Wrap a captioning transformer with a differentiable CheXpert-style labeler.

    In teacher-forced mode, the caption distribution is perturbed with Gumbel
    noise and softmaxed. It is then turned into soft token embeddings through
    the transformer's own token embedding table and fed to ``diff_chexpert``,
    so the labeler's predictions stay differentiable with respect to the
    generated caption.
    """
    def __init__(self, transformer, diff_chexpert, freeze_chexpert=True):
        super().__init__()
        self.transformer = transformer
        self.diff_chexpert = diff_chexpert
        if freeze_chexpert:
            for param in self.diff_chexpert.parameters():
                param.requires_grad = False

    def generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float causal mask: 0 on/below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def apply_chexpert(self, embed, caption_length):
        """Run ``diff_chexpert``'s RNN and attention heads on pre-computed embeddings.

        Same steps as ``LSTM_Attn.forward`` minus its token-embedding lookup, so
        ``embed`` (B, L, E) may be soft (non-integer) embeddings. Returns per-head
        class probabilities (B, D, C).
        """
        padding_mask = self.diff_chexpert.generate_pad_mask(embed.shape[0], embed.shape[1], caption_length)
        output, (_, _) = self.diff_chexpert.rnn(embed)
        y_hats = [attn(output, padding_mask) for attn in self.diff_chexpert.attns]
        y_hats = torch.stack(y_hats, dim=1)
        y_hats = torch.softmax(y_hats, dim=-1)
        return y_hats

    def forward(self, image, caption = None, caption_length=None, bos_id=1, eos_id=2, pad_id=3, max_len=300, temperature=1, beta=1):
        """Teacher-forced captioning plus labeler predictions, or plain decoding.

        Args:
            caption_length: per-sample caption lengths passed to the labeler's pad mask.
            temperature: Gumbel-softmax temperature (lower -> closer to one-hot).
            beta: scale applied to the Gumbel noise.

        Returns:
            With ``caption``: ``(preds, chexpert_preds)`` of shapes (B, L, S) and (B, D, C).
            Without: whatever ``self.transformer`` returns when no caption is given.
        """
        if caption is None:
            return self.transformer(image, caption, bos_id, eos_id, pad_id, max_len) # (B,L')

        preds = self.transformer(image, caption, bos_id, eos_id, pad_id, max_len) # (B,L,S)
        logits = torch.log(preds) # (B,L,S)
        one_hot_preds = self.gumbel_softmax_sample(logits, temperature, beta) # (B,L,S)
        # Soft one-hot rows times the embedding table give the expected token embedding.
        # This equals embedding every vocab id per batch item, without building a (B,S,E) copy.
        # It assumes a plain nn.Embedding (no max_norm renormalisation).
        preds_embed = one_hot_preds @ self.transformer.token_embedding.weight # (B,L,S) x (S,E) = (B,L,E)
        chexpert_preds = self.apply_chexpert(preds_embed, caption_length) # (B,D,C)
        return preds, chexpert_preds # (B,L,S), (B,D,C)

    def sample_gumbel(self, shape, device, eps=1e-20):
        """Draw standard Gumbel noise of ``shape`` directly on ``device``."""
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature, beta):
        """Softmax of ``(logits + beta * Gumbel noise) / temperature`` over the last axis."""
        y = logits + beta * self.sample_gumbel(logits.size(), logits.device)
        return torch.softmax(y / temperature, dim=-1)
    
# --- CheXpert ---
class TanhAttention(nn.Module):
    """Additive (tanh) attention pooling over time followed by a linear classifier.

    ``mask`` is added to the attention scores before the softmax, so positions
    to ignore should hold ``-inf`` and valid ones ``0``.
    """
    def __init__(self, hidden_size, dropout=0.5, num_out=2):
        super().__init__()
        self.attn1 = nn.Linear(hidden_size, hidden_size // 2)
        self.attn2 = nn.Linear(hidden_size // 2, 1, bias=False)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_size, num_out)

    def forward(self, output, mask):
        """output: (B, L, H) sequence; mask: (B, L) additive. Returns (B, num_out) logits."""
        scores = self.attn2(torch.tanh(self.attn1(output))).squeeze(-1) # (B,L)
        weights = F.softmax(scores + mask, dim=1) # (B,L)
        pooled = output.transpose(1, 2).matmul(weights.unsqueeze(2)).squeeze(2) # (B,H)
        return self.fc(self.dropout(pooled))

class DotAttention(nn.Module):
    """Scaled single-vector dot-product attention pooling followed by a linear classifier.

    Scores are ``attn(output) / sqrt(hidden_size)``. ``mask`` is added before the
    softmax (``-inf`` = ignore, ``0`` = keep).
    """
    def __init__(self, hidden_size, dropout=0.5, num_out=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.attn = nn.Linear(hidden_size, 1, bias=False)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_size, num_out)

    def forward(self, output, mask):
        """output: (B, L, H) sequence; mask: (B, L) additive. Returns (B, num_out) logits."""
        scale = self.hidden_size ** 0.5
        scores = (self.attn(output) / scale).squeeze(-1) # (B,L)
        weights = F.softmax(scores + mask, dim=1) # (B,L)
        pooled = output.transpose(1, 2).matmul(weights.unsqueeze(2)).squeeze(2) # (B,H)
        return self.fc(self.dropout(pooled))
    
class LSTM_Attn(nn.Module):
    """Bi-LSTM report labeler with one tanh-attention head per topic.

    ``forward`` returns, for each of ``num_topics`` topics, a softmax over
    ``num_states`` classes, with shape (B, num_topics, num_states).
    """
    def __init__(self, num_tokens, embed_dim, hidden_size, num_topics, num_states, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(num_tokens, embed_dim)
        self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, batch_first=True, bidirectional=True)
        # A bidirectional LSTM emits 2 * hidden_size features per step.
        self.attns = nn.ModuleList([TanhAttention(hidden_size*2, dropout, num_states) for i in range(num_topics)])

    def generate_pad_mask(self, batch_size, max_len, caption_length, device=None):
        """Additive attention mask (B, max_len): 0 for the first ``caption_length[b]`` positions, -inf after.

        ``device`` defaults to the device of this module's parameters, so the mask
        is created wherever the model lives (CPU or GPU).
        """
        if device is None:
            device = self.embed.weight.device
        mask = torch.full((batch_size, max_len), fill_value=float('-inf'), dtype=torch.float, device=device)
        for ind, cap_len in enumerate(caption_length):
            mask[ind][:cap_len] = 0
        return mask

    def forward(self, caption, caption_length):
        """caption: (B, L) token ids; caption_length: per-sample valid lengths."""
        x = self.embed(caption) # (B,L,E)
        output, (_,_) = self.rnn(x) # (B,L,2H)

        padding_mask = self.generate_pad_mask(caption.shape[0], caption.shape[1], caption_length)

        y_hats = [attn(output, padding_mask) for attn in self.attns]
        y_hats = torch.stack(y_hats, dim=1) # (B,num_topics,num_states)
        y_hats = torch.softmax(y_hats, dim=-1)
        return y_hats

class CNN_Attn(nn.Module):
    """Multi-kernel 1-D CNN text classifier with one DotAttention head per class.

    Each kernel size K in ``kernels`` yields a feature sequence of length
    ``L - (K - 1)``; the sequences are concatenated along time and every attention
    head pools over the whole concatenation.

    Args:
        embed_weight: (num_tokens, emb_dim) pretrained embedding matrix (kept frozen).
        emb_dim: embedding width.
        filters: number of output channels per convolution.
        kernels: list of convolution kernel sizes.
        num_classes: number of attention/classifier heads.
    """

    def __init__(self, embed_weight, emb_dim, filters, kernels, num_classes=14):

        super(CNN_Attn, self).__init__()

        # Cast to the default dtype so the embeddings match the Conv1d weights;
        # a float64 numpy matrix would otherwise make conv(x) fail on dtype.
        self.embed = nn.Embedding.from_pretrained(torch.as_tensor(embed_weight, dtype=torch.get_default_dtype()), freeze=True)

        self.Ks = kernels

        self.convs = nn.ModuleList([nn.Conv1d(emb_dim, filters, K) for K in self.Ks])

        self.attns = nn.ModuleList([DotAttention(filters) for _ in range(num_classes)])

    def generate_pad_mask(self, batch_size, max_len, caption_length, device='cuda'):
        """Additive mask over the concatenated conv outputs.

        Segment i (kernel K_i) has ``max_len - (K_i - 1)`` positions and starts right
        after the previous segments. Within it, the first ``caption_length - (K_i - 1)``
        positions come from real tokens (0 in the mask); the rest are -inf.
        ``device`` defaults to ``'cuda'`` for backward compatibility.
        """
        seg_lens = [max_len - (K - 1) for K in self.Ks]
        mask = torch.full((batch_size, sum(seg_lens)), fill_value=float('-inf'), dtype=torch.float, device=device)
        for row, cap_len in enumerate(caption_length):
            offset = 0
            for K, seg_len in zip(self.Ks, seg_lens):
                valid = max(0, min(int(cap_len) - (K - 1), seg_len))
                mask[row, offset:offset + valid] = 0
                offset += seg_len  # next kernel's segment starts after this one
        return mask

    def forward(self, encoded_captions, caption_length):
        """Return (B,num_classes,2) logits for (B,L) token ids (DotAttention default ``num_out``)."""
        x = self.embed(encoded_captions).transpose(1, 2) # (B,E,L)

        batch_size = encoded_captions.size(0)
        max_len = encoded_captions.size(1)
        padding_mask = self.generate_pad_mask(batch_size, max_len, caption_length, device=encoded_captions.device)

        output = [F.relu(conv(x)).transpose(1, 2) for conv in self.convs]
        output = torch.cat(output, dim=1)

        y_hats = [attn(output, padding_mask) for attn in self.attns]
        y_hats = torch.stack(y_hats, dim=1)

        return y_hats

In [None]:


class ST(nn.Module): # Show and Tell
    """Show-and-Tell baseline: an LSTM captioner conditioned on pooled image features.

    The pooled image feature is projected to the token-embedding width and fed to
    the LSTM as the first input step, followed by the caption tokens.

    Args:
        image_encoder: module returning ``(avg_features, wxh_features)``; only the
            pooled ``avg_features`` (width ``fc_features``) are used.
        num_tokens: vocabulary size.
        fc_features: width of ``avg_features``.
        embed_dim: token-embedding width (also the projected image width).
        hidden_size: LSTM hidden width.
        dropout: dropout applied before the output projection.
        freeze_encoder: if True, set ``requires_grad=False`` on all encoder parameters
            (the original paper keeps its ImageNet-pretrained DenseNet frozen).
    """

    def __init__(self, image_encoder, num_tokens, fc_features=1024, embed_dim=256, hidden_size=512, dropout=0.1, freeze_encoder=True):
        super().__init__()
        self.embed = nn.Embedding(num_tokens, embed_dim)

        self.image_encoder = image_encoder
        if freeze_encoder:
            for param in self.image_encoder.parameters():
                param.requires_grad = False

        self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, batch_first=True)

        self.fc1 = nn.Linear(fc_features, embed_dim)
        self.fc2 = nn.Linear(hidden_size, num_tokens)
        self.dropout = nn.Dropout(dropout)

    def forward(self, image, caption=None, caption_length=None, bos_id=1, eos_id=2, pad_id=3, max_len=300):
        """Score a caption with teacher forcing, or greedily decode one.

        Args:
            image: input accepted by ``image_encoder``.
            caption: (B,L) token ids. If given, return next-token probabilities.
            caption_length, eos_id, pad_id: accepted for interface compatibility; unused.
            bos_id: start token for greedy decoding.
            max_len: number of tokens generated after ``bos_id``.

        Returns:
            (B,L,S) softmax probabilities when ``caption`` is given; otherwise
            (B,1+max_len) token ids starting with ``bos_id`` (no early stop on EOS).
        """
        avg_features, _ = self.image_encoder(image) # (B,F)
        img_features = self.fc1(avg_features).unsqueeze(1) # (B,1,E)

        if caption is not None:
            cap_embed = self.embed(caption) # (B,L,E)
            output, _ = self.rnn(torch.cat([img_features, cap_embed], dim=1)) # (B,1+L,H)
            preds = torch.softmax(self.fc2(self.dropout(output)), dim=-1) # (B,1+L,S)
            return preds[:, 1:, :] # (B,L,S)

        # Greedy decoding: prime the LSTM with the image step, then feed back one token
        # at a time while carrying (h, c) instead of re-running the whole prefix.
        caption = torch.full((img_features.shape[0], 1), bos_id, dtype=torch.long, device=img_features.device) # (B,1)
        _, state = self.rnn(img_features)
        for _ in range(max_len):
            output, state = self.rnn(self.embed(caption[:, -1:]), state) # (B,1,H)
            logits = self.fc2(self.dropout(output[:, -1])) # (B,S)
            # argmax over logits equals argmax over softmax(logits).
            next_token = logits.argmax(dim=-1, keepdim=True) # (B,1)
            caption = torch.cat([caption, next_token], dim=-1) # (B,L'+1)
        return caption # (B,1+max_len)
        
class SAT(nn.Module): # Show, Attend and Tell
    """Placeholder for a Show-Attend-and-Tell captioner.

    NOTE(review): despite the name, this model applies no attention. It ignores the
    spatial ``wxh_features`` and is functionally identical to ``ST`` (same layers, same
    forward). Implementing real attention would change the parameter set and break
    existing checkpoints, so it is only flagged here.

    Args:
        image_encoder: module returning ``(avg_features, wxh_features)``; only the
            pooled ``avg_features`` (width ``fc_features``) are used.
        num_tokens: vocabulary size.
        fc_features: width of ``avg_features``.
        embed_dim: token-embedding width (also the projected image width).
        hidden_size: LSTM hidden width.
        dropout: dropout applied before the output projection.
        freeze_encoder: if True, set ``requires_grad=False`` on all encoder parameters.
    """

    def __init__(self, image_encoder, num_tokens, fc_features=1024, embed_dim=256, hidden_size=512, dropout=0.1, freeze_encoder=True):
        super().__init__()
        self.embed = nn.Embedding(num_tokens, embed_dim)

        self.image_encoder = image_encoder
        if freeze_encoder:
            for param in self.image_encoder.parameters():
                param.requires_grad = False

        self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, batch_first=True)

        self.fc1 = nn.Linear(fc_features, embed_dim)
        self.fc2 = nn.Linear(hidden_size, num_tokens)
        self.dropout = nn.Dropout(dropout)

    def forward(self, image, caption=None, caption_length=None, bos_id=1, eos_id=2, pad_id=3, max_len=300):
        """Score a caption with teacher forcing, or greedily decode one.

        Args:
            image: input accepted by ``image_encoder``.
            caption: (B,L) token ids. If given, return next-token probabilities.
            caption_length, eos_id, pad_id: accepted for interface compatibility; unused.
            bos_id: start token for greedy decoding.
            max_len: number of tokens generated after ``bos_id``.

        Returns:
            (B,L,S) softmax probabilities when ``caption`` is given; otherwise
            (B,1+max_len) token ids starting with ``bos_id`` (no early stop on EOS).
        """
        avg_features, _ = self.image_encoder(image) # (B,F)
        img_features = self.fc1(avg_features).unsqueeze(1) # (B,1,E)

        if caption is not None:
            cap_embed = self.embed(caption) # (B,L,E)
            output, _ = self.rnn(torch.cat([img_features, cap_embed], dim=1)) # (B,1+L,H)
            preds = torch.softmax(self.fc2(self.dropout(output)), dim=-1) # (B,1+L,S)
            return preds[:, 1:, :] # (B,L,S)

        # Greedy decoding with a carried LSTM state (O(L) instead of re-running the prefix).
        caption = torch.full((img_features.shape[0], 1), bos_id, dtype=torch.long, device=img_features.device) # (B,1)
        _, state = self.rnn(img_features)
        for _ in range(max_len):
            output, state = self.rnn(self.embed(caption[:, -1:]), state) # (B,1,H)
            logits = self.fc2(self.dropout(output[:, -1])) # (B,S)
            next_token = logits.argmax(dim=-1, keepdim=True) # (B,1)
            caption = torch.cat([caption, next_token], dim=-1) # (B,L'+1)
        return caption # (B,1+max_len)

In [None]:

# --- Transformer Modules ---
class MultiheadAttention(nn.Module):
    """Batch-first wrapper around ``nn.MultiheadAttention`` with residual + LayerNorm.

    ``query`` attends over ``input`` (used as both keys and values). The attended
    values are added back to the query and layer-normalised.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout)
        self.normalize = nn.LayerNorm(embed_dim)

    def forward(self, input, query, pad_mask=None, att_mask=None):
        """Attend from ``query`` (B,Q,E) over ``input`` (B,V,E).

        Args:
            pad_mask: optional (B,V) key padding mask (True = ignore that key).
            att_mask: optional (Q,V) attention mask.

        Returns:
            (B,Q,E) normalised outputs and (B,Q,V) head-averaged attention weights.
        """
        # The wrapped module is sequence-first, so swap the batch and sequence axes.
        keys = input.transpose(0, 1)     # (V,B,E)
        queries = query.transpose(0, 1)  # (Q,B,E)
        attended, weights = self.attention(queries, keys, keys, key_padding_mask=pad_mask, attn_mask=att_mask)
        normed = self.normalize(attended + queries)  # (Q,B,E)
        return normed.transpose(0, 1), weights  # (B,Q,E), (B,Q,V)
    
class PointwiseFeedForward(nn.Module):
    """Position-wise MLP (Linear-ReLU-Dropout-Linear) with a residual + LayerNorm.

    Args:
        emb_dim: input/output width.
        fwd_dim: hidden width of the MLP.
        dropout: dropout between the two linear layers.
    """

    def __init__(self, emb_dim, fwd_dim, dropout=0.0):
        super().__init__()
        layers = [
            nn.Linear(emb_dim, fwd_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fwd_dim, emb_dim),
        ]
        self.fwd_layer = nn.Sequential(*layers)
        self.normalize = nn.LayerNorm(emb_dim)

    def forward(self, input):
        """Return ``LayerNorm(input + MLP(input))`` for a (B,L,E) input."""
        return self.normalize(input + self.fwd_layer(input))

class TransformerLayer(nn.Module):
    """One Transformer block: self-attention sublayer followed by a feed-forward sublayer.

    Each sublayer applies its own residual connection and LayerNorm
    (see MultiheadAttention and PointwiseFeedForward).
    """

    def __init__(self, embed_dim, num_heads, fwd_dim, dropout=0.0):
        super().__init__()
        self.attention = MultiheadAttention(embed_dim, num_heads, dropout)
        self.fwd_layer = PointwiseFeedForward(embed_dim, fwd_dim, dropout)

    def forward(self, input, pad_mask=None, att_mask=None):
        """Self-attend over ``input`` (B,L,E); return the (B,L,E) output and (B,L,L) weights."""
        attended, weights = self.attention(input, input, pad_mask=pad_mask, att_mask=att_mask)
        return self.fwd_layer(attended), weights

class TNN(nn.Module):
    """Transformer encoder over token ids or precomputed token embeddings.

    Learned positional embeddings are added to the token embeddings, followed by
    dropout and ``num_layers`` TransformerLayer blocks.

    Args:
        embed_dim, num_heads, fwd_dim, dropout, num_layers: Transformer settings.
        num_tokens, num_posits: sizes of the token / position tables created when
            no module is supplied.
        token_embedding, posit_embedding: optional modules to use instead of new
            ``nn.Embedding`` tables (checked with ``is None``, so any module,
            including an empty ``nn.Sequential``, is honoured).
    """

    def __init__(self, embed_dim, num_heads, fwd_dim, dropout=0.1, num_layers=1,
                num_tokens=1, num_posits=1, token_embedding=None, posit_embedding=None):
        super().__init__()
        self.token_embedding = nn.Embedding(num_tokens, embed_dim) if token_embedding is None else token_embedding
        self.posit_embedding = nn.Embedding(num_posits, embed_dim) if posit_embedding is None else posit_embedding
        self.transform = nn.ModuleList([TransformerLayer(embed_dim, num_heads, fwd_dim, dropout) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, token_index=None, token_embed=None, pad_mask=None, pad_id=-1, att_mask=None):
        """Encode a batch.

        Args:
            token_index: (B,L) token ids; takes precedence over ``token_embed``.
            token_embed: (B,L,E) precomputed embeddings, used when ``token_index`` is None.
            pad_mask: optional (B,L) bool mask (True = padding). When ``token_index`` is
                given and this is None, it is derived as ``token_index == pad_id``.
            pad_id: padding id; the default -1 matches no real token.
            att_mask: optional (L,L) attention mask forwarded to every layer.

        Returns:
            (B,L,E) encoded features.

        Raises:
            ValueError: if both ``token_index`` and ``token_embed`` are None.
        """
        if token_index is not None:
            if pad_mask is None:
                pad_mask = (token_index == pad_id) # (B,L)
            token_embed = self.token_embedding(token_index) # (B,L,E)
        elif token_embed is None:
            raise ValueError('token_index or token_embed must not be None')

        batch_size, seq_len = token_embed.shape[0], token_embed.shape[1]
        posit_index = torch.arange(seq_len, device=token_embed.device).unsqueeze(0).expand(batch_size, seq_len) # (B,L)
        final_embed = self.dropout(token_embed + self.posit_embedding(posit_index)) # (B,L,E)

        for layer in self.transform:
            final_embed = layer(final_embed, pad_mask, att_mask)[0]

        return final_embed # (B,L,E)

# --- Convolution Modules ---
class CNN(nn.Module):
    """Wrap a torchvision-style backbone so it returns pooled and spatial features.

    Args:
        model: the backbone module.
        model_type: any name containing ``'res'`` (ResNet/ResNeSt: its children end
            with ``..., avgpool, fc``) or ``'dense'`` (DenseNet: uses ``model.features``).

    Raises:
        ValueError: for any other ``model_type``.
    """

    def __init__(self, model, model_type='resnet'):
        super().__init__()
        kind = model_type.lower()
        if 'res' in kind: # resnet, resnet-50, resnest-50, ...
            layers = list(model.children())[:-1] # drop the classifier head
            pool = layers.pop()                  # keep the global pool separately
            self.feature = nn.Sequential(*layers)
            self.average = pool
        elif 'dense' in kind: # densenet, densenet-121, densenet121, ...
            # NOTE(review): [:-1] drops the last child of model.features; for torchvision
            # DenseNets that is presumably the final BatchNorm (norm5) — confirm intent.
            self.feature = nn.Sequential(*list(model.features.children())[:-1])
            self.average = nn.AdaptiveAvgPool2d((1, 1))
        else:
            raise ValueError('Unsupported model_type!')

    def forward(self, input):
        """Return ``(avg_features (B,F), wxh_features (B,F,W,H))`` for a (B,C,W,H) image batch."""
        wxh_features = self.feature(input)
        pooled = self.average(wxh_features) # (B,F,1,1)
        return pooled.view(pooled.size(0), -1), wxh_features

class MVCNN(nn.Module):
    """Multi-view wrapper: run a per-image CNN on every view and max-pool across views.

    View slots whose position code is -1 are treated as padding and excluded from the max.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        """Encode a batch of view stacks.

        Args:
            input: pair ``(img, pos)`` with ``img`` (B,V,C,W,H) and ``pos`` (B,V);
                ``pos == -1`` marks a padded (missing) view.

        Returns:
            ``(avg_features (B,F), wxh_features (B,F,W',H'))``, the element-wise max over
            the valid views. A sample with no valid view gets the value -1 everywhere,
            matching the previous behaviour for that case.
        """
        img = input[0] # (B,V,C,W,H)
        pos = input[1] # (B,V)
        B,V,C,W,H = img.shape

        img = img.view(B*V,C,W,H)
        avg, wxh = self.model(img) # (B*V,F), (B*V,F,W',H')
        avg = avg.view(B,V,-1) # (B,V,F)
        wxh = wxh.view(B,V,wxh.shape[-3],wxh.shape[-2],wxh.shape[-1]) # (B,V,F,W',H')

        # Exclude padded views with -inf rather than a finite sentinel: a sentinel
        # like -1 would win the max against genuine activations below -1.
        msk = (pos == -1) # (B,V)
        wxh = wxh.masked_fill(msk.view(B,V,1,1,1), float('-inf'))
        avg = avg.masked_fill(msk.view(B,V,1), float('-inf'))

        wxh_features = wxh.max(dim=1)[0] # (B,F,W',H')
        avg_features = avg.max(dim=1)[0] # (B,F)

        # Samples without any valid view would be -inf; keep the legacy -1 instead.
        no_view = msk.all(dim=1) # (B,)
        wxh_features = wxh_features.masked_fill(no_view.view(B,1,1,1), -1.0)
        avg_features = avg_features.masked_fill(no_view.view(B,1), -1.0)
        return avg_features, wxh_features

# --- Main Modules ---
class Classifier(nn.Module):
    """Multi-label topic/state classifier over image and/or text features.

    Image features are projected to one ``embed_dim`` vector per topic. Text features
    are pooled per topic by attending from the topic embeddings over the text encoder
    output. The fused per-topic embedding then attends over the ``num_states`` state
    embeddings; those attention weights are returned as per-topic state scores.

    Args:
        num_topics: number of topics (labels), T.
        num_states: number of states per topic, C.
        cnn: image encoder returning ``(avg_features, wxh_features)``; None disables images.
        tnn: text encoder (TNN); None disables text.
        fc_features: width of the image encoder's pooled features.
        embed_dim: shared embedding width, E.
        num_heads: heads for the text-pooling attention and the classifier attention.
        dropout: dropout on image features and in the text-pooling attention.
    """

    def __init__(self, num_topics, num_states, cnn=None, tnn=None,
                fc_features=2048, embed_dim=128, num_heads=1, dropout=0.1):
        super().__init__()

        # For img & txt embedding and feature extraction
        self.cnn = cnn
        self.tnn = tnn
        self.img_features = nn.Linear(fc_features, num_topics * embed_dim) if cnn is not None else None
        self.txt_features = MultiheadAttention(embed_dim, num_heads, dropout) if tnn is not None else None

        # For classification
        self.topic_embedding = nn.Embedding(num_topics, embed_dim)
        self.state_embedding = nn.Embedding(num_states, embed_dim)
        self.attention = MultiheadAttention(embed_dim, num_heads)

        # Some constants
        self.num_topics = num_topics
        self.num_states = num_states
        self.dropout = nn.Dropout(dropout)
        self.normalize = nn.LayerNorm(embed_dim)

    def _topic_state_embed(self, features):
        """Return ``(topic_embed (B,T,E), state_embed (B,C,E))`` for ``features``' batch size and device."""
        batch_size, device = features.shape[0], features.device
        topic_index = torch.arange(self.num_topics, device=device).unsqueeze(0).repeat(batch_size, 1) # (B,T)
        state_index = torch.arange(self.num_states, device=device).unsqueeze(0).repeat(batch_size, 1) # (B,C)
        return self.topic_embedding(topic_index), self.state_embedding(state_index)

    def forward(self, img=None, txt=None, lbl=None, txt_embed=None, pad_mask=None, pad_id=3, threshold=0.5, get_embed=False, get_txt_att=False):
        """Classify every topic into one of ``num_states`` states.

        Args:
            img: image input for ``cnn`` — (B,C,W,H) or ((B,V,C,W,H), (B,V)).
            txt: (B,L) token ids for ``tnn``; takes precedence over ``txt_embed``.
            lbl: optional (B,T) ground-truth states; used for the returned embedding
                instead of the predicted states (teacher forcing).
            txt_embed: (B,L,E) precomputed text embeddings, used when ``txt`` is None.
            pad_mask: (B,L) bool text padding mask; derived from ``pad_id`` for ``txt``
                when None and ``pad_id >= 0``.
            threshold: predicted state is 1 where ``att[:, :, 1] > threshold``
                (scalar, or anything broadcastable against (B,T)).
            get_embed: also return ``final_embed + state embedding`` (B,T,E).
            get_txt_att: also return the (B,T,L) text attention (only when text is given).

        Returns:
            ``att`` (B,T,C), or ``(att, embed)`` / ``(att, txt_attention)`` per the flags.

        Raises:
            ValueError: if neither an image nor any text input is given.
        """
        has_img = img is not None
        has_txt = txt is not None or txt_embed is not None
        if not (has_img or has_txt):
            raise ValueError('img and (txt or txt_embed) must not be all none')

        # --- Get img and txt features ---
        if has_img:
            img_features, _ = self.cnn(img) # (B,F)
            img_features = self.dropout(img_features) # (B,F)

        if txt is not None:
            if pad_id >= 0 and pad_mask is None:
                pad_mask = (txt == pad_id)
            txt_features = self.tnn(token_index=txt, pad_mask=pad_mask) # (B,L,E)
        elif txt_embed is not None:
            txt_features = self.tnn(token_embed=txt_embed, pad_mask=pad_mask) # (B,L,E)

        # --- Fuse img and txt features into one embedding per topic ---
        topic_embed, state_embed = self._topic_state_embed(img_features if has_img else txt_features)
        if has_img:
            img_features = self.img_features(img_features).view(img_features.shape[0], self.num_topics, -1) # (B,T,E)
        if has_txt:
            txt_features, txt_attention = self.txt_features(txt_features, topic_embed, pad_mask) # (B,T,E), (B,T,L)

        if has_img and has_txt:
            final_embed = self.normalize(img_features + txt_features) # (B,T,E)
        elif has_img:
            final_embed = img_features # (B,T,E)
        else:
            final_embed = txt_features # (B,T,E)

        # --- Classifier output ---
        emb, att = self.attention(state_embed, final_embed) # (B,T,E), (B,T,C)

        if lbl is not None: # Teacher forcing
            emb = self.state_embedding(lbl) # (B,T,E)
        else:
            emb = self.state_embedding((att[:,:,1] > threshold).long()) # (B,T,E)

        if get_embed:
            return att, final_embed + emb # (B,T,C), (B,T,E)
        if get_txt_att and has_txt:
            return att, txt_attention # (B,T,C), (B,T,L)
        return att # (B,T,C)

class Generator(nn.Module):
    """Transformer report decoder conditioned on a sequence of source embeddings.

    The source embeddings are prepended to the target token embeddings and the
    concatenation runs through ``num_layers`` TransformerLayer blocks under a combined
    mask (see ``generate_square_subsequent_mask_with_source``). Next-token scores are
    the attention weights of each position over the full token-embedding table.
    """

    def __init__(self, num_tokens, num_posits, embed_dim=128, num_heads=1, fwd_dim=256, dropout=0.1, num_layers=12):
        super().__init__()
        self.token_embedding = nn.Embedding(num_tokens, embed_dim)
        self.posit_embedding = nn.Embedding(num_posits, embed_dim)
        self.transform = nn.ModuleList([TransformerLayer(embed_dim, num_heads, fwd_dim, dropout) for _ in range(num_layers)])
        self.attention = MultiheadAttention(embed_dim, num_heads)
        self.num_tokens = num_tokens
        self.num_posits = num_posits

    def forward(self, source_embed, token_index=None, source_pad_mask=None, target_pad_mask=None, max_len=300, top_k=1, bos_id=1, pad_id=3, mode='eye'):
        """Score target tokens (teacher forcing) or decode when ``token_index`` is None.

        Args:
            source_embed: (B,T,E) conditioning embeddings.
            token_index: (B,L) target ids; None switches to ``infer``.
            source_pad_mask, target_pad_mask: optional (B,T) / (B,L) bool masks (True = pad).
            max_len, top_k, bos_id, pad_id: decoding settings forwarded to ``infer``.
            mode: source self-attention pattern, ``'eye'`` or ``'one'``.

        Returns:
            ``(att (B,L,K), emb (B,L,E))`` when scoring, with K = ``num_tokens``;
            otherwise the (B,L') token ids from ``infer``.
        """
        if token_index is None: # --- Inference Phase ---
            return self.infer(source_embed, source_pad_mask, max_len, top_k, bos_id, pad_id)

        # --- Training/Testing Phase ---
        src_len = source_embed.shape[1]
        batch_size, tgt_len = token_index.shape[0], token_index.shape[1]

        # Token embedding + learned positional embedding for the target.
        posit_index = torch.arange(tgt_len, device=token_index.device).unsqueeze(0).repeat(batch_size, 1) # (B,L)
        target_embed = self.token_embedding(token_index) + self.posit_embedding(posit_index) # (B,L,E)

        # Joint sequence, pad mask and attention mask for the decoder.
        final_embed = torch.cat([source_embed, target_embed], dim=1) # (B,T+L,E)
        if source_pad_mask is None:
            source_pad_mask = torch.zeros((source_embed.shape[0], src_len), dtype=torch.bool, device=source_embed.device) # (B,T)
        if target_pad_mask is None:
            target_pad_mask = torch.zeros((batch_size, tgt_len), dtype=torch.bool, device=target_embed.device) # (B,L)
        pad_mask = torch.cat([source_pad_mask, target_pad_mask], dim=1) # (B,T+L)
        att_mask = self.generate_square_subsequent_mask_with_source(src_len, tgt_len, mode).to(final_embed.device) # (T+L,T+L)

        for layer in self.transform:
            final_embed = layer(final_embed, pad_mask, att_mask)[0]

        # Next-token scores: each position attends over the whole vocabulary table.
        vocab_index = torch.arange(self.num_tokens, device=token_index.device).unsqueeze(0).repeat(batch_size, 1) # (B,K)
        vocab_embed = self.token_embedding(vocab_index) # (B,K,E)
        emb, att = self.attention(vocab_embed, final_embed) # (B,T+L,E), (B,T+L,K)

        # Drop the source positions.
        return att[:, src_len:, :], emb[:, src_len:, :] # (B,L,K), (B,L,E)

    def infer(self, source_embed, source_pad_mask=None, max_len=100, top_k=1, bos_id=1, pad_id=3):
        """Beam-search decode (greedy when ``top_k == 1``).

        Hypotheses are ranked by accumulated negative log-probability (lower is better).
        Decoding always runs ``max_len - 1`` steps; there is no early stop on EOS.

        Returns:
            (B,max_len) token ids of the best hypothesis, starting with ``bos_id``.
        """
        batch_size, device = source_embed.shape[0], source_embed.device
        outputs = torch.full((top_k, batch_size, 1), bos_id, dtype=torch.long, device=device) # (K,B,1) <s>
        scores = torch.zeros((top_k, batch_size), dtype=torch.float32, device=device) # (K,B)
        # All beams start from the same <s>; let only beam 0 expand on the first step,
        # otherwise the K identical beams produce K copies of the same hypotheses.
        scores[1:] = float('inf')

        for _ in range(1, max_len):
            possible_outputs = []
            possible_scores = []

            for k in range(top_k):
                output = outputs[k] # (B,L)
                score = scores[k] # (B)

                att, _ = self.forward(source_embed, output, source_pad_mask=source_pad_mask, target_pad_mask=(output == pad_id))
                val, idx = torch.topk(att[:, -1, :], top_k) # (B,K)
                nll = -torch.log(val) # (B,K)

                for i in range(top_k):
                    possible_outputs.append(torch.cat([output, idx[:, i].view(-1, 1)], dim=-1).unsqueeze(0)) # (1,B,L+1)
                    possible_scores.append((score + nll[:, i].view(-1)).unsqueeze(0)) # (1,B)

            possible_outputs = torch.cat(possible_outputs, dim=0) # (K^2,B,L+1)
            possible_scores = torch.cat(possible_scores, dim=0) # (K^2,B)

            # Keep the K hypotheses with the LOWEST accumulated NLL.
            _, best = torch.topk(possible_scores, top_k, dim=0, largest=False) # (K,B)
            cols = torch.arange(batch_size, device=device).unsqueeze(0).expand_as(best) # (K,B)
            outputs = possible_outputs[best, cols] # (K,B,L+1)
            scores = possible_scores[best, cols] # (K,B)

        _, best = torch.topk(scores, 1, dim=0, largest=False) # (1,B)
        cols = torch.arange(batch_size, device=device).unsqueeze(0) # (1,B)
        return outputs[best, cols].squeeze(0) # (B,L)

    def generate_square_subsequent_mask_with_source(self, src_sz, tgt_sz, mode='eye'):
        """(src+tgt)^2 additive mask: targets are causal and see every source position;
        sources never see targets; source-to-source follows ``mode``.

        ``mode='one'``: each source position sees every source position.
        ``mode='eye'``: each source position sees only itself.

        Raises:
            ValueError: for any other ``mode``.
        """
        # The causal mask already gives the target block its causal pattern,
        # lets targets see all sources and hides targets from sources.
        mask = self.generate_square_subsequent_mask(src_sz + tgt_sz)
        if mode == 'one':
            mask[:src_sz, :src_sz] = self.generate_square_mask(src_sz)
        elif mode == 'eye':
            mask[:src_sz, :src_sz] = self.generate_square_identity_mask(src_sz)
        else:
            raise ValueError('Mode must be "one" or "eye".')
        return mask

    def generate_square_subsequent_mask(self, sz):
        """Causal mask: 0 on and below the diagonal, -inf above."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def generate_square_identity_mask(self, sz):
        """0 on the diagonal, -inf elsewhere."""
        return torch.full((sz, sz), float('-inf')).fill_diagonal_(0.0)

    def generate_square_mask(self, sz):
        """All-zero mask (every position visible)."""
        return torch.zeros(sz, sz)

# --- Full Models ---
class ClsGen(nn.Module):
    """Classifier-conditioned report generator.

    The classifier produces per-topic state scores and per-topic embeddings; a learned
    per-topic label embedding is added to those and the sum is the generator's source.

    Args:
        classifier: module called as ``classifier(img=..., txt=..., lbl=..., threshold=...,
            pad_id=..., get_embed=True)`` returning ``(mlc (B,T,C), embed (B,T,E))``.
        generator: decoder called with ``source_embed`` and ``token_index`` (see Generator).
        num_topics: number of topics T.
        embed_dim: embedding width E.
    """

    def __init__(self, classifier, generator, num_topics, embed_dim):
        super().__init__()
        self.classifier = classifier
        self.generator = generator
        self.label_embedding = nn.Embedding(num_topics, embed_dim)

    def forward(self, image, history=None, caption=None, label=None, threshold=0.15, bos_id=1, eos_id=2, pad_id=3, max_len=300, get_emb=False):
        """Classify, then score (``caption`` given) or decode (``caption`` None) a report.

        ``eos_id`` is accepted for interface compatibility and unused.

        Returns:
            ``(cap_gen, img_mlc)``, plus ``cap_emb`` when scoring with ``get_emb=True``.
        """
        if label is not None:
            label = label.long()
        img_mlc, img_emb = self.classifier(img=image, txt=history, lbl=label, threshold=threshold, pad_id=pad_id, get_embed=True) # (B,T,C), (B,T,E)
        lbl_idx = torch.arange(img_emb.shape[1], device=img_emb.device).unsqueeze(0).repeat(img_emb.shape[0], 1) # (B,T)
        src_emb = img_emb + self.label_embedding(lbl_idx) # (B,T,E)

        if caption is None:
            cap_gen = self.generator(source_embed=src_emb, token_index=caption, max_len=max_len, bos_id=bos_id, pad_id=pad_id) # (B,L)
            return cap_gen, img_mlc

        cap_gen, cap_emb = self.generator(source_embed=src_emb, token_index=caption, target_pad_mask=(caption == pad_id)) # (B,L,S), (B,L,E)
        if get_emb:
            return cap_gen, img_mlc, cap_emb
        return cap_gen, img_mlc

class ClsGenInt(nn.Module):
    """ClsGen plus an "interpreter" that re-classifies the generated report embeddings.

    Args:
        clsgen: a ClsGen-compatible module.
        interpreter: text classifier called as ``interpreter(txt_embed=..., pad_mask=...)``.
        freeze_evaluator: if True, set ``requires_grad=False`` on every interpreter parameter.
    """

    def __init__(self, clsgen, interpreter, freeze_evaluator=True):
        super().__init__()
        self.clsgen = clsgen
        self.interpreter = interpreter

        # Freeze evaluator's parameters
        if freeze_evaluator:
            for param in self.interpreter.parameters():
                param.requires_grad = False

    def forward(self, image, history=None, caption=None, label=None, threshold=0.15, bos_id=1, eos_id=2, pad_id=3, max_len=300):
        """Score (``caption`` given) or decode (``caption`` None) a report.

        Returns:
            ``(cap_gen, img_mlc, cap_mlc)`` when scoring, where ``cap_mlc`` is the
            interpreter's output on the generator's caption embeddings; otherwise
            whatever ``clsgen`` returns when decoding.
        """
        # Keyword arguments keep this call correct even if ClsGen's parameter order changes.
        common = dict(image=image, history=history, caption=caption, label=label, threshold=threshold,
                      bos_id=bos_id, eos_id=eos_id, pad_id=pad_id, max_len=max_len)
        if caption is None:
            return self.clsgen(**common, get_emb=False)

        cap_gen, img_mlc, cap_emb = self.clsgen(**common, get_emb=True)
        cap_mlc = self.interpreter(txt_embed=cap_emb, pad_mask=(caption == pad_id))
        return cap_gen, img_mlc, cap_mlc

In [None]:
# --- Helper Functions ---
def find_optimal_cutoff(target, predicted):
    """Return the ROC threshold that maximises sqrt(TPR * (1 - FPR)).

    Args:
        target: binary ground-truth labels.
        predicted: scores for the positive class.

    Returns:
        The threshold from ``sklearn.metrics.roc_curve`` with the highest G-mean.
    """
    false_pos, true_pos, cutoffs = metrics.roc_curve(target, predicted)
    g_mean = np.sqrt(true_pos * (1 - false_pos))
    best = np.argmax(g_mean)
    return cutoffs[best]

def infer(data_loader, model, device='cpu', threshold=None):
    """Run ``model`` over ``data_loader`` in eval mode and collect outputs and targets.

    Each batch is a ``(source, target)`` pair; both are moved to ``device``.
    The collected outputs and targets are passed through ``data_to_device`` (default
    arguments) and then ``data_concatenate``.

    Args:
        threshold: if not None, call the model with image (``source[0]``), clinical
            history (``source[3]``) and ``threshold``. ClsGen/ClsGenInt forward it to
            the classifier, so a scalar or a per-label array both work. If None, call
            ``model(source[0])``.

    Returns:
        ``(outputs, targets)`` as produced by ``data_concatenate``.
    """
    model.eval()
    outputs = []
    targets = []

    with torch.no_grad():
        for source, target in tqdm(data_loader):
            source = data_to_device(source, device)
            target = data_to_device(target, device)

            # `is not None`: a per-label numpy threshold would make `!= None` ambiguous.
            if threshold is not None:
                output = model(image=source[0], history=source[3], threshold=threshold)
                # Alternative input combinations used in experiments:
                # output = model(image=source[0], threshold=threshold)
                # output = model(image=source[0], history=source[3], label=source[2])
                # output = model(image=source[0], label=source[2])
            else:
                # output = model(source[0], source[1])
                output = model(source[0])

            outputs.append(data_to_device(output))
            targets.append(data_to_device(target))

        outputs = data_concatenate(outputs)
        targets = data_concatenate(targets)

    return outputs, targets

In [None]:
def load(path, model, optimizer=None, scheduler=None):
    """Restore a checkpoint into ``model`` and, best-effort, ``optimizer`` / ``scheduler``.

    The checkpoint is loaded onto CPU and must contain ``'epoch'``, ``'stats'`` and
    ``'model_state_dict'``. Failures to restore the optimizer or scheduler (missing
    key or incompatible state) are reported and otherwise ignored.

    Returns:
        ``(epoch, stats)`` from the checkpoint.
    """
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    print(checkpoint.keys())
    # --- Model Statistics ---
    epoch = checkpoint['epoch']
    stats = checkpoint['stats']
    # --- Model Parameters ---
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except Exception as err: # mismatched/missing optimizer state is not fatal
            print('Cannot load the optimizer:', err)
    if scheduler is not None:
        try:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        except Exception as err: # mismatched/missing scheduler state is not fatal
            print('Cannot load the scheduler:', err)
    return epoch, stats

In [None]:

# --- Hyperparameters ---
# Runtime: expose only GPU 0, use one CPU thread, fix the RNG seed.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["OMP_NUM_THREADS"] = "1"
torch.set_num_threads(1)
torch.manual_seed(seed=123)

RELOAD = True # True / False
PHASE = 'INFER' # TRAIN / TEST / INFER
DATASET_NAME = 'NLMCXR' # NIHCXR / NLMCXR / MIMIC
BACKBONE_NAME = 'DenseNet121' # ResNeSt50 / ResNet50 / DenseNet121
MODEL_NAME = 'ClsGenInt' # ClsGen / ClsGenInt / VisualTransformer / GumbelTransformer

# MIMIC and NLMCXR share the epoch budget and LR milestone; they differ only in the
# non-training batch size (MIMIC 64, NLMCXR 16). Training always uses batches of 8.
if DATASET_NAME in ('MIMIC', 'NLMCXR'):
    EPOCHS = 50 # overfitting starts after about 20 epochs
    MILESTONES = [25] # LR is reduced 10x at these epochs
    BATCH_SIZE = 8 if PHASE == 'TRAIN' else (64 if DATASET_NAME == 'MIMIC' else 16)
else:
    raise ValueError('Invalid DATASET_NAME')

if __name__ == "__main__":
    # --- Choose Inputs/Outputs
    if MODEL_NAME in ['ClsGen', 'ClsGenInt']:
        SOURCES = ['image','caption','label','history']
        TARGETS = ['caption','label']
        KW_SRC = ['image','caption','label','history']
        KW_TGT = None
        KW_OUT = None
                
    elif MODEL_NAME == 'VisualTransformer':
        SOURCES = ['image','caption']
        TARGETS = ['caption']# ,'label']
        KW_SRC = ['image','caption'] # kwargs of Classifier
        KW_TGT = None
        KW_OUT = None
        
    elif MODEL_NAME == 'GumbelTransformer':
        SOURCES = ['image','caption','caption_length']
        TARGETS = ['caption','label']
        KW_SRC = ['image','caption','caption_length'] # kwargs of Classifier
        KW_TGT = None
        KW_OUT = None
        
    else:
        raise ValueError('Invalid BACKBONE_NAME')
        
    # --- Choose a Dataset ---
    if DATASET_NAME == 'NLMCXR':
        INPUT_SIZE = (256,256)
        MAX_VIEWS = 2
        NUM_LABELS = 114
        NUM_CLASSES = 2

        dataset = NLMCXR('/kaggle/input/', INPUT_SIZE, view_pos=['AP','PA','LATERAL'], max_views=MAX_VIEWS, sources=SOURCES, targets=TARGETS)
        train_data, val_data, test_data = dataset.get_subsets(seed=123)
        
        VOCAB_SIZE = len(dataset.vocab)
        POSIT_SIZE = dataset.max_len
        COMMENT = 'MaxView{}_NumLabel{}_{}History'.format(MAX_VIEWS, NUM_LABELS, 'No' if 'history' not in SOURCES else '')
        
    else:
        raise ValueError('Invalid DATASET_NAME')

    # --- Choose a Backbone --- 
    if BACKBONE_NAME == 'ResNeSt50':
        torch.hub.list('zhanghang1989/ResNeSt', force_reload=True)
        backbone = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True)
        FC_FEATURES = 2048
        
    elif BACKBONE_NAME == 'ResNet50':
        backbone = models.resnet50(pretrained=True)
        FC_FEATURES = 2048
        
    elif BACKBONE_NAME == 'DenseNet121':
        backbone = torch.hub.load('pytorch/vision:v0.5.0', 'densenet121', pretrained=True)
        FC_FEATURES = 1024
        
    else:
        raise ValueError('Invalid BACKBONE_NAME')

    # --- Choose a Model ---
    # Sets the optimizer hyper-parameters (LR, WD), the architecture sizes, `model`
    # and `criterion` for the selected MODEL_NAME. The backbone / FC_FEATURES chosen
    # above feed the image encoder. Statement order matters: NUM_HEADS / NUM_LAYERS
    # are reassigned between constructor calls, so each module sees different values.
    if MODEL_NAME == 'ClsGen':
        LR = 3e-4 # Fastest LR
        WD = 1e-2 # Avoid overfitting with L2 regularization
        DROPOUT = 0.1 # Avoid overfitting
        NUM_EMBEDS = 256
        FWD_DIM = 256
        
        # Sizes used by the TNN built below: 8 heads, 1 layer.
        NUM_HEADS = 8
        NUM_LAYERS = 1
        
        # Image encoder: backbone CNN wrapped by MVCNN (presumably combines several
        # views per study — confirm in MVCNN).
        cnn = CNN(backbone, BACKBONE_NAME)
        cnn = MVCNN(cnn)
        tnn = TNN(embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, fwd_dim=FWD_DIM, dropout=DROPOUT, num_layers=NUM_LAYERS, num_tokens=VOCAB_SIZE, num_posits=POSIT_SIZE)
        
        # Not enough memory to run 8 heads and 12 layers, instead 1 head is enough.
        # These values apply only to the Classifier and Generator constructed next.
        NUM_HEADS = 1
        NUM_LAYERS = 12
        
        cls_model = Classifier(num_topics=NUM_LABELS, num_states=NUM_CLASSES, cnn=cnn, tnn=tnn, fc_features=FC_FEATURES, embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, dropout=DROPOUT)
        gen_model = Generator(num_tokens=VOCAB_SIZE, num_posits=POSIT_SIZE, embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, fwd_dim=FWD_DIM, dropout=DROPOUT, num_layers=NUM_LAYERS)
        
        model = ClsGen(cls_model, gen_model, NUM_LABELS, NUM_EMBEDS)
        # Token id 3 is excluded from the loss (presumably the padding id — TODO confirm
        # against the SentencePiece vocab).
        criterion = CELossTotal(ignore_index=3)
        
    elif MODEL_NAME == 'ClsGenInt':
        # ClsGen plus an Interpreter classifier, fine-tuned from pretrained checkpoints.
        LR = 3e-5 # Slower LR to fine-tune the model (Open-I)
        # LR = 3e-6 # Slower LR to fine-tune the model (MIMIC)
        WD = 1e-2 # Avoid overfitting with L2 regularization
        DROPOUT = 0.1 # Avoid overfitting
        NUM_EMBEDS = 256
        FWD_DIM = 256
        
        # Sizes used by the TNN built below: 8 heads, 1 layer.
        NUM_HEADS = 8
        NUM_LAYERS = 1
        
        cnn = CNN(backbone, BACKBONE_NAME)
        cnn = MVCNN(cnn)
        tnn = TNN(embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, fwd_dim=FWD_DIM, dropout=DROPOUT, num_layers=NUM_LAYERS, num_tokens=VOCAB_SIZE, num_posits=POSIT_SIZE)
        
        # Not enough memory to run 8 heads and 12 layers, instead 1 head is enough.
        # These values apply only to the Classifier and Generator constructed next.
        NUM_HEADS = 1
        NUM_LAYERS = 12
        
        cls_model = Classifier(num_topics=NUM_LABELS, num_states=NUM_CLASSES, cnn=cnn, tnn=tnn, fc_features=FC_FEATURES, embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, dropout=DROPOUT)
        gen_model = Generator(num_tokens=VOCAB_SIZE, num_posits=POSIT_SIZE, embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, fwd_dim=FWD_DIM, dropout=DROPOUT, num_layers=NUM_LAYERS)
        
        clsgen_model = ClsGen(cls_model, gen_model, NUM_LABELS, NUM_EMBEDS)
        # Wrapped in DataParallel before loading, presumably so the checkpoint's
        # state-dict keys (saved from a DataParallel model) match; `.module` unwraps
        # it again further down.
        clsgen_model = nn.DataParallel(clsgen_model)
        
        # Initialize from the pretrained ClsGen weights. When RELOAD is set, the full
        # ClsGenInt checkpoint is loaded into `model` after this chain instead.
        if not RELOAD:
            checkpoint_path_from = 'checkpoints/{}_ClsGen_{}_{}.pt'.format(DATASET_NAME, BACKBONE_NAME, COMMENT)
            last_epoch, (best_metric, test_metric) = load(checkpoint_path_from, clsgen_model)
            print('Reload From: {} | Last Epoch: {} | Validation Metric: {} | Test Metric: {}'.format(checkpoint_path_from, last_epoch, best_metric, test_metric))
        
        # Initialize the Interpreter module: a Classifier without an image encoder
        # (cnn=None) on top of a fresh TNN, back at 8 heads / 1 layer.
        NUM_HEADS = 8
        NUM_LAYERS = 1
        
        tnn = TNN(embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, fwd_dim=FWD_DIM, dropout=DROPOUT, num_layers=NUM_LAYERS, num_tokens=VOCAB_SIZE, num_posits=POSIT_SIZE)
        int_model = Classifier(num_topics=NUM_LABELS, num_states=NUM_CLASSES, cnn=None, tnn=tnn, embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, dropout=DROPOUT)
        int_model = nn.DataParallel(int_model)
        
        # Pretrained Interpreter weights (same RELOAD rule as above).
        if not RELOAD:
            checkpoint_path_from = 'checkpoints/{}_Transformer_MaxView2_NumLabel{}.pt'.format(DATASET_NAME, NUM_LABELS)
            last_epoch, (best_metric, test_metric) = load(checkpoint_path_from, int_model)
            print('Reload From: {} | Last Epoch: {} | Validation Metric: {} | Test Metric: {}'.format(checkpoint_path_from, last_epoch, best_metric, test_metric))
        
        # Combine the unwrapped sub-models (moved to CPU); the combined model is wrapped
        # in DataParallel again after this chain. freeze_evaluator=True presumably keeps
        # the Interpreter fixed during fine-tuning — confirm in ClsGenInt.
        model = ClsGenInt(clsgen_model.module.cpu(), int_model.module.cpu(), freeze_evaluator=True)
        criterion = CELossTotalEval(ignore_index=3)
        
    elif MODEL_NAME == 'VisualTransformer':
        # Clinically Coherent X-ray Report baseline (Justin et al.) - No Fine-tune.
        # Encoder-decoder Transformer over the MVCNN image encoder; freeze_encoder=True
        # is passed through to Transformer.
        LR = 5e-5
        WD = 1e-2 # Avoid overfitting with L2 regularization
        DROPOUT = 0.1 # Avoid overfitting
        NUM_EMBEDS = 256
        NUM_HEADS = 8
        FWD_DIM = 4096
        NUM_LAYERS_ENC = 1
        NUM_LAYERS_DEC = 6
        
        cnn = CNN(backbone, BACKBONE_NAME)
        cnn = MVCNN(cnn)
        model = Transformer(image_encoder=cnn, num_tokens=VOCAB_SIZE, num_posits=POSIT_SIZE, 
                            fc_features=FC_FEATURES, embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, fwd_dim=FWD_DIM, 
                            dropout=DROPOUT, num_layers_enc=NUM_LAYERS_ENC, num_layers_dec=NUM_LAYERS_DEC, freeze_encoder=True)
        criterion = CELossShift(ignore_index=3)
    else:
        raise ValueError('Invalid MODEL_NAME')
    
    # --- Main program ---
    # Loaders: the training split is shuffled and drops its incomplete final batch;
    # validation/test keep their natural order and every sample.
    train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, drop_last=True)
    val_loader = data.DataLoader(val_data, batch_size=BATCH_SIZE, num_workers=8, shuffle=False)
    test_loader = data.DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=8, shuffle=False)

    # Wrap for multi-GPU execution and optimize only parameters that still require
    # gradients (sub-modules frozen by the constructors, e.g. via freeze_encoder /
    # freeze_evaluator, are left out).
    model = nn.DataParallel(model)
    optimizer = optim.AdamW(
        (param for param in model.parameters() if param.requires_grad),
        lr=LR,
        weight_decay=WD,
    )
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES)

    print('Total Parameters:', sum(param.numel() for param in model.parameters()))

    # Defaults for a fresh run; replaced below when reloading a checkpoint.
    last_epoch, best_metric = -1, 1e9

    # Checkpoint to reload from. Generic naming pattern:
    # '/kaggle/input/torch-models/{}_{}_{}_{}.pt'.format(DATASET_NAME, MODEL_NAME, BACKBONE_NAME, COMMENT)
    checkpoint_path_from = '/kaggle/input/nlm-torch-model/NLMCXR_ClsGenInt_DenseNet121_MaxView2_NumLabel114_History.pt'
    checkpoint_path_to = f'/kaggle/working/checkpoints/{DATASET_NAME}_{MODEL_NAME}_{BACKBONE_NAME}_{COMMENT}.pt'

    if RELOAD:
        # Reload model, optimizer and scheduler together. The weights-only variant
        # load(checkpoint_path_from, model) is the one to use for fine-tuning.
        last_epoch, (best_metric, test_metric) = load(checkpoint_path_from, model, optimizer, scheduler)
        print(f'Reload From: {checkpoint_path_from} | Last Epoch: {last_epoch} | Validation Metric: {best_metric} | Test Metric: {test_metric}')

    # --- Run the selected phase: TRAIN, TEST or INFER ---
    if PHASE == 'TRAIN':
        scaler = torch.cuda.amp.GradScaler()  # loss scaler handed to train() for mixed precision

        # Per-epoch resume checkpoint, kept separate from the best-model file so that
        # progress saves never overwrite the best weights (point checkpoint_path_from
        # at this file to resume an interrupted run).
        ckpt_root, ckpt_ext = os.path.splitext(checkpoint_path_to)
        checkpoint_path_last = ckpt_root + '_Last' + ckpt_ext

        for epoch in range(last_epoch+1, EPOCHS):
            print('Epoch:', epoch)
            train_loss = train(train_loader, model, optimizer, criterion, device='cuda', kw_src=KW_SRC, kw_tgt=KW_TGT, kw_out=KW_OUT, scaler=scaler)
            val_loss = test(val_loader, model, criterion, device='cuda', kw_src=KW_SRC, kw_tgt=KW_TGT, kw_out=KW_OUT, return_results=False)
            test_loss = test(test_loader, model, criterion, device='cuda', kw_src=KW_SRC, kw_tgt=KW_TGT, kw_out=KW_OUT, return_results=False)

            scheduler.step()

            # checkpoint_path_to only ever holds the lowest-validation-loss model.
            if best_metric > val_loss:
                best_metric = val_loss
                save(checkpoint_path_to, model, optimizer, scheduler, epoch, (val_loss, test_loss))
                print('New Best Metric: {}'.format(best_metric))
                print('Saved To:', checkpoint_path_to)

            # Saved only after the epoch has finished and labelled with `epoch`, so a
            # reload resumes at epoch+1 without skipping any training. The running best
            # validation loss goes in the first metric slot, so a resumed run (which
            # reads it back as best_metric) keeps comparing against the true best.
            save(checkpoint_path_last, model, optimizer, scheduler, epoch, (best_metric, test_loss))
    elif PHASE == 'TEST':
        # Output the file list for inspection: one line per test sample holding the two
        # fields of test_data.idx_pidsid[i], separated by a space.
        img_list_path = 'outputs/{}_{}_{}_{}_Img.txt'.format(DATASET_NAME, MODEL_NAME, BACKBONE_NAME, COMMENT)
        with open(img_list_path, 'w') as out_file_img:
            for i in range(len(test_data.idx_pidsid)):
                out_file_img.write(test_data.idx_pidsid[i][0] + ' ' + test_data.idx_pidsid[i][1] + '\n')

    elif PHASE == 'INFER':
        def _detokenize(vocab, token_ids):
            """Convert a sequence of SentencePiece ids into a readable string.

            Stops at '</s>', skips '<s>', maps the '▁' piece to a single space and
            surrounds ',', '.', '-', ':' with spaces; any other piece is appended as is.
            """
            text = ''
            for token_id in token_ids:
                tok = vocab.id_to_piece(int(token_id))
                if tok == '</s>':
                    break  # Manually stop after the end-of-sequence piece is reached
                elif tok == '<s>':
                    continue
                elif tok == '▁':  # space
                    if text and text[-1] != ' ':
                        text += ' '
                elif tok in [',', '.', '-', ':']:
                    if text and text[-1] != ' ':
                        text += ' ' + tok + ' '
                    else:
                        text += tok + ' '
                else:  # regular word piece
                    text += tok
            return text

        txt_test_outputs, txt_test_targets = infer(test_loader, model, device='cpu', threshold=0.25)
        gen_outputs = txt_test_outputs[0]
        gen_targets = txt_test_targets[0]
        vocab = dataset.vocab

        # `with` guarantees the files are flushed and closed; otherwise the handles stay
        # alive as notebook globals and buffered output may never reach disk.
        out_prefix = '/kaggle/working/x_{}_{}_{}_{}'.format(DATASET_NAME, MODEL_NAME, BACKBONE_NAME, COMMENT)
        with open(out_prefix + '_Ref.txt', 'w') as out_file_ref, \
             open(out_prefix + '_Hyp.txt', 'w') as out_file_hyp, \
             open(out_prefix + '_Lbl.txt', 'w') as out_file_lbl:
            # Generated report (hypothesis) and ground-truth report (reference), one per line.
            for i in range(len(gen_outputs)):
                out_file_hyp.write(_detokenize(vocab, gen_outputs[i]) + '\n')
                out_file_ref.write(_detokenize(vocab, gen_targets[i]) + '\n')

            # Ground-truth labels of every test sample, space-separated.
            for i in tqdm(range(len(test_data))):
                target = test_data[i][1] # caption, label
                out_file_lbl.write(' '.join(map(str, target[1])) + '\n')

    else:
        raise ValueError('Invalid PHASE')