In [1]:
import os
import pandas as pd
from PIL import Image
import numpy as np

from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import nn

import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
import torchvision.models as models
import torch.optim as optim
import torch.nn.functional as F

random_state = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    def __init__(self, freq_threshold):
        self.index_to_string = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.string_to_index = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold
        
    def __len__(self):
        return len(self.index_to_string)
    
    @staticmethod
    def tokenizer(text):
        #returns tokenized sentence in a form of list with single words
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
    
    def return_index_to_string(self):
        return self.index_to_string
    
    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4
        
        for word in self.tokenizer(sentence_list):
            if word not in frequencies:
                frequencies[word] = 1
            else:
                frequencies[word] += 1
            if frequencies[word] == self.freq_threshold:
                self.string_to_index[word] = idx
                self.index_to_string[idx] = word
                idx += 1
                
    def numericalize(self, text):
        #it takes a sentence and returns indexes of words in it as a form of list
        tokenized_text = self.tokenizer(text)
        
        return [self.string_to_index[token] if token in self.string_to_index 
                else self.string_to_index["<UNK>"] for token in tokenized_text]

In [3]:
class MyCollate:
        def __init__(self, pad_idx):
            self.pad_idx = pad_idx
            
        def __call__(self, batch):
            imgs = [item[0].unsqueeze(0) for item in batch]
            imgs = torch.cat(imgs, dim=0)
            targets = [item[1] for item in batch]
            targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)
            
            return imgs, targets    
        
#here previously batch_first was set to false, but i changed it to true

In [4]:
class XRayDataset(Dataset):
    def __init__(self, cvs_file, path, transform, freq_threshold, size=(624,512)):
        #path is for general folder, csv_file is csv file not path
        self.path = path
        self.dataframe = cvs_file
        self.size = size
        self.transform = transform
        self.freq_thresh = freq_threshold
        
        self.img_col = self.dataframe["Imgs_paths"]
        self.findings_col = self.dataframe["findings"]
        
        self.vocab = Vocabulary(self.freq_thresh)
        self.st = ""
        self.vocab.build_vocabulary(self.st.join(self.findings_col.tolist()[:]))
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, index):
        finding = self.findings_col[index]
        img_id = self.img_col[index]
        img_path = self.path + "/Images/" + img_id + ".png"
        img = Image.open(img_path).resize(self.size)
        
        if self.transform is not None:
            img = self.transform(img)
            
        numericalized_caption = [self.vocab.string_to_index["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(finding)
        numericalized_caption.append(self.vocab.string_to_index["<EOS>"])
        
        return img, torch.tensor(numericalized_caption)

In [5]:
def get_loader(csv_file, path, transform, batch_size, freq_threshold, shuffle=True):
    dataset = XRayDataset(csv_file, path, transform, freq_threshold)
    
    pad_idx = dataset.vocab.string_to_index["<PAD>"]
    
    loader = DataLoader(
        dataset=dataset,
        batch_size = batch_size,
        shuffle = shuffle,
        collate_fn = MyCollate(pad_idx=pad_idx)        
    )

    return loader, dataset

In [6]:
class EncoderCNN(nn.Module):
    def __init__(self):
        super(EncoderCNN, self).__init__()
        
        inception = models.inception_v3(pretrained=True, aux_logits=False)
        for param in inception.parameters():
            param.requires_grad_(False)
        modules = list(inception.children())[:-4]
        self.inception = nn.Sequential(*modules).to(device)
        
    def forward(self, images):
        features = self.inception(images)
        #2048
        
        batch, feature_maps, size_1, size_2 = features.size()
        features = features.permute(0,2,3,1)
        features = features.view(batch, size_1*size_2, feature_maps) 
        
        return features

In [7]:
class BahdanauAttention(nn.Module):
    def __init__(self, num_features, hidden_dim, output_dim=1):
        super(BahdanauAttention, self).__init__()
        self.num_features = num_features #2048
        self.hidden_dim = hidden_dim #512
        self.output_dim = output_dim #1
        
        #this layer learns attention over features from encoder
        self.W_a = nn.Linear(self.num_features, self.hidden_dim).to(device)
        #this layers learns attention over previous decoder state
        self.U_a = nn.Linear(self.hidden_dim, self.hidden_dim).to(device)
        #this produces the output from two previous ones
        self.V_a = nn.Linear(self.hidden_dim, self.output_dim).to(device)
        
    def forward(self, features, decoder_hidden):
        #features are from encoder cnn
        #decoder_hidden is previous hidden state from decoder lstm
        
        decoder_hidden = decoder_hidden.unsqueeze(1) #to add time steps/batch
        Wa = self.W_a(features) #calculating attn over features from encoder
        Ua = self.U_a(decoder_hidden) #calculating attention over previous hidden state
        atten_tan = torch.tanh(Wa+Ua) #calculating tangent from added two previous tensors
        atten_score = self.V_a(atten_tan) #linear layer to calculate attention weights
        atten_weight = F.softmax(atten_score, dim=1) #activation to this layer
        
        context = torch.sum(atten_weight*features, dim=1) #multiplication of features and atten weights
        #print("Atten weight before squeeze",atten_weight.size())
        atten_weight = atten_weight.squeeze(dim=2)
        #print("Atten weight after squeeze dim=2", atten_weight.size())
        
        return context, atten_weight

In [8]:
class DecoderRNN(nn.Module):
    
    def __init__(self, num_features, embedding_dim, hidden_dim, vocab_size, p=0.5):
        super(DecoderRNN, self).__init__()
        
        self.num_features = num_features #2048
        self.embedding_dim = embedding_dim #256
        self.hidden_dim = hidden_dim #512
        self.vocab_size = vocab_size
        self.sample_temp = 0.5
        
        self.embeddings = nn.Embedding(vocab_size, embedding_dim).to(device) #embedding of caption's words
        self.lstm = nn.LSTMCell(embedding_dim + num_features, hidden_dim).to(device) #concat vector of embedded hidden state and 
        #context vector from attention
        self.fc = nn.Linear(hidden_dim, vocab_size).to(device)
        
        self.attention = BahdanauAttention(self.num_features, self.hidden_dim)
        self.dropout = nn.Dropout(p=p)
        #initialization to hidden state and cell memory
        self.init_h = nn.Linear(num_features, hidden_dim).to(device)
        self.init_c = nn.Linear(num_features, hidden_dim).to(device)
        
    def forward(self, captions, features, sample_prob = 0.0):
        
        embed = self.embeddings(captions) #should be of size batch_size, sequence_length, embedding_dimension
        #print("Size of embed: ", embed.size())
        h, c = self.init_hidden(features)
        #print("Size of h, c: ", h.size(), c.size())
        sequence_len = captions.size(1)
        feature_size = features.size(1)
        batch_size = features.size(0)
        #print("Size of features: ", features.size())
        outputs = torch.zeros(batch_size, sequence_len, self.vocab_size).to(device)
        atten_weights = torch.zeros(batch_size, sequence_len, feature_size).to(device)
        
        for t in range(sequence_len): #loop for each word in caption
            sample_prob = 0.0 if t == 0 else 0.5
            use_sampling = np.random.random() < sample_prob
            if use_sampling == False:
                word_embed = embed[:,t,:] #embedding from current time step
            context, atten_weight = self.attention(features, h)
            input_concat = torch.cat([word_embed, context], 1) #input to lstm
            h, c = self.lstm(input_concat, (h,c))
            h = self.dropout(h)
            output = self.fc(h)
            if use_sampling == True:
                scaled_output = output / self.sample_temp
                scoring = F.log_softmax(scaled_output, dim=1)
                top_idx = scoring.topk(1)[1]
                word_embed = self.embeddings(top_idx).squeeze(1) #word predicted previously
            outputs[:,t,:] = output
            atten_weights[:,t,:] = atten_weight
        return outputs, atten_weights
    
    def init_hidden(self, features):
        mean_annotations = torch.mean(features, dim=1)
        h0 = self.init_h(mean_annotations)
        c0 = self.init_c(mean_annotations)
        return h0, c0

In [9]:
transform = transforms.Compose(
    [
        transforms.Resize((299,299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

In [27]:
def save_checkopoint(state, file_path):
  print("Saving")
  torch.save(state, file_path)
  print("Saved")

def train(n_epochs, encoder, decoder, optimizer, criterion, batch_size, data_loader, valid_loader, checkpoint_path = None, save_model=True):

  loss_history = []
  valid_loss_history = []

  total_step = len(data_loader)//batch_size + 1
  total_valid_step = len(valid_loader)//batch_size + 1

  for epoch in range(n_epochs):
    epoch_loss = 0.0
    epoch_valid_loss = 0.0
    
    #training loop
    for i_step in range(1, total_step):
        encoder.eval()
        decoder.train()
        
        images, captions = next(iter(data_loader)) #this should obtain a whole batch of images and captions for them
        captions_target = captions[:,1:].to(device) #without start token
        captions_train = captions[:,:-1].to(device) #lstm must predict end token

        img = images.to(device)

        decoder.zero_grad()
        encoder.zero_grad()

        features = encoder(img)
        outputs, atten_weights = decoder(captions=captions_train, features = features)

        loss = criterion(outputs.view(-1, vocab_size), captions_target.reshape(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        stats = 'Epoch train: [%d/%d], Batch train: [%d/%d], Loss train: %.4f' % (epoch, n_epochs, i_step, total_step, loss.item())
        print('\r' + stats, end="")

      #validation loop
    for i_step_val in range(1, total_valid_step):
        decoder.eval()

        images, captions = next(iter(valid_loader))
        captions_target = captions[:,1:].to(device)
        captions_valid = captions[:,:-1].to(device)
        
        img = images.to(device)

        decoder.zero_grad()
        encoder.zero_grad()

        features = encoder(img)
        outputs, atten_weights = decoder(captions=captions_valid, features=features)

        loss_val = criterion(outputs.view(-1, vocab_size), captions_valid.reshape(-1))
        epoch_valid_loss += loss_val
            
    epoch_loss_avg = epoch_loss / total_step
    epoch_valid_loss_avg = epoch_valid_loss / total_valid_step
    loss_history.append(epoch_loss_avg)
    valid_loss_history.append(epoch_valid_loss_avg)
      
    print('\r')
    print('Epoch train:', epoch)
    print('\r' + 'Avg. Loss train: %.4f, Avg.' % (epoch_loss_avg), end="")
    print('\r')

    checkpoint = {'state_dict': decoder.state_dict(), 'optimizer': optimizer.state_dict()}
    if save_model:
      save_checkopoint(checkpoint, checkpoint_path)

  return loss_history, valid_loss_history

In [10]:
"""def train(epoch, encoder, decoder, optimizer, criterion, total_step, num_epochs, data_loader, write_file=None, save_every=None):
    #function for a single epoch
    epoch_loss = 0.0
    
    
    for i_step in range(1, total_step+1):
        encoder.eval()
        decoder.train()
        
        images, captions = next(iter(data_loader)) #this should obtain a whole batch of images and captions for them
        captions_target = captions[:,1:].to(device) #without start token
        captions_train = captions[:,:-1].to(device) #lstm must predict end token

        img = images.to(device)

        decoder.zero_grad()
        encoder.zero_grad()

        features = encoder(img)
        outputs, atten_weights = decoder(captions=captions_train, features = features)

        loss = criterion(outputs.view(-1, vocab_size), captions_target.reshape(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        stats = 'Epoch train: [%d/%d], Step train: [%d/%d], Loss train: %.4f' % (epoch, num_epochs, i_step, total_step, loss.item())
        print('\r' + stats, end="")
            
    epoch_loss_avg = epoch_loss / total_step
    
    print('\r')
    print('Epoch train:', epoch)
    print('\r' + 'Avg. Loss train: %.4f, Avg.' % (epoch_loss_avg), end="")
    print('\r')

In [11]:
path = '/content/drive/MyDrive/Data/XrayNLP'
csv_path = path + "/" + "dataframe.csv"
checkpoint_path = path + "/" + "attention1.pth.tar"

batch_size = 32
embed_size = 256
hidden_size = 512
num_features = 2048
n_epochs = 5
freq_threshold = 3
learning_rate = 0.001

df = pd.read_csv(csv_path)
csv_train, csv_validate, csv_test = np.split(df.sample(frac=1, random_state=random_state), [int(.8*len(df)), int(.9*len(df))])
csv_train, csv_validate, csv_test = csv_train.reset_index(), csv_validate.reset_index(), csv_test.reset_index()

_, dataset = get_loader(df, path, transform, batch_size, freq_threshold)
train_loader, train_dataset = get_loader(csv_train, path, transform, batch_size, freq_threshold)
valid_loader, valid_dataset = get_loader(csv_validate, path, transform, batch_size, freq_threshold)
test_loader, test_dataset = get_loader(csv_test, path, transform, batch_size, freq_threshold)

vocab_size = len(train_dataset.vocab)
#total_step = len(train_dataset)//batch_size + 1

encoder = EncoderCNN()
decoder = DecoderRNN(num_features, embed_size, hidden_size, vocab_size)
optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.vocab.string_to_index["<PAD>"])


Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


HBox(children=(FloatProgress(value=0.0, max=108949747.0), HTML(value='')))




In [28]:
loss_history, valid_loss_history = train(n_epochs, encoder, decoder, optimizer, criterion, batch_size, train_loader, valid_loader, checkpoint_path)

Epoch train: [0/5], Batch train: [5/6], Loss train: 4.5836
Epoch train: 0
Avg. Loss train: 4.0113, Avg.
Saving
Saved
Epoch train: [1/5], Batch train: [5/6], Loss train: 4.6062
Epoch train: 1
Avg. Loss train: 3.9158, Avg.
Saving
Saved
Epoch train: [2/5], Batch train: [5/6], Loss train: 4.5171
Epoch train: 2
Avg. Loss train: 3.9088, Avg.
Saving
Saved
Epoch train: [3/5], Batch train: [5/6], Loss train: 4.5243
Epoch train: 3
Avg. Loss train: 3.7007, Avg.
Saving
Saved
Epoch train: [4/5], Batch train: [5/6], Loss train: 4.2242
Epoch train: 4
Avg. Loss train: 3.6639, Avg.
Saving
Saved
