In [2]:
import os
import pandas as pd
from PIL import Image
import numpy as np

from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import nn

import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
import torchvision.models as models
import torch.optim as optim
import torch.nn.functional as F

# Seed value kept in one place for reproducibility.
# NOTE(review): nothing in the visible cells passes this to an RNG
# (torch / numpy / DataLoader) — confirm it is applied further down.
random_state = 42
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

The batch dimension comes first everywhere in this notebook.

In [11]:
# Load spaCy's small English pipeline once at module level; its tokenizer is
# what Vocabulary.tokenizer uses to split text into words.
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    """Word-level vocabulary mapping tokens to integer indices and back.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``; words learned by :meth:`build_vocabulary` are
    appended after them.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that only contains the special tokens.

        Args:
            freq_threshold: number of occurrences a word must reach before it
                is added. Counting starts at 1, so a threshold below 1 is
                never reached and no word is ever added.
        """
        self.index_to_string = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.string_to_index = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.index_to_string)

    @staticmethod
    def tokenizer(text):
        """Split ``text`` into a list of lower-cased word strings.

        Uses the tokenizer of the module-level spaCy pipeline ``spacy_eng``.
        """
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def return_index_to_string(self):
        """Return the index -> word dictionary (the live object, not a copy)."""
        return self.index_to_string

    def build_vocabulary(self, sentence_list):
        """Add every word that reaches ``freq_threshold`` occurrences.

        Args:
            sentence_list: either one string (the whole corpus) or an iterable
                of strings (one per sentence / report). Each string is
                tokenized separately, so words from neighbouring sentences
                are never merged.

        Words are indexed in the order in which they first reach the
        threshold. Calling the method again extends the vocabulary: numbering
        continues after the current last index and words that are already
        known keep their index. Frequencies are counted per call.
        """
        # A bare string is treated as a one-element corpus; iterating over it
        # directly would tokenize character by character.
        if isinstance(sentence_list, str):
            sentence_list = [sentence_list]

        frequencies = {}
        # Continue after the existing entries instead of restarting at 4, so a
        # repeated call cannot overwrite indices that are already assigned.
        next_idx = len(self.index_to_string)

        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1
                reached_threshold = frequencies[word] == self.freq_threshold
                if reached_threshold and word not in self.string_to_index:
                    self.string_to_index[word] = next_idx
                    self.index_to_string[next_idx] = word
                    next_idx += 1

    def numericalize(self, text):
        """Convert ``text`` into a list of vocabulary indices.

        Tokens that are not in the vocabulary map to the ``<UNK>`` index.
        """
        unk_idx = self.string_to_index["<UNK>"]
        return [self.string_to_index.get(token, unk_idx)
                for token in self.tokenizer(text)]

In [12]:
class MyCollate:
    """Collate function that batches images and pads captions to equal length."""

    def __init__(self, pad_idx):
        # Index used to fill the tail of captions shorter than the longest one.
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into two batch tensors.

        Returns:
            A tuple ``(imgs, targets)``: the images stacked along a new
            leading batch axis, and the captions padded with ``pad_idx`` in
            batch-first layout.
        """
        images = torch.stack([sample[0] for sample in batch], dim=0)
        captions = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=True,
            padding_value=self.pad_idx,
        )
        return images, captions
        
# NOTE: batch_first used to be False here; it is now True, so the padded targets are (batch, seq_len).

In [14]:
class XRayDataset(Dataset):
    """Dataset of X-ray images paired with their numericalized findings text.

    Each item is ``(image, caption)`` where ``caption`` is a 1-D tensor of
    vocabulary indices wrapped in ``<SOS>`` ... ``<EOS>``.
    """

    def __init__(self, cvs_file, path, transform, freq_threshold, size=(624,512)):
        """
        Args:
            cvs_file: the already-loaded DataFrame (not a file path). It must
                have the columns ``"Imgs_paths"`` (image ids, without the
                ``.png`` extension) and ``"findings"`` (report text).
            path: root folder of the data; images are read from
                ``<path>/Images/<id>.png``.
            transform: callable applied to the resized PIL image, or ``None``.
                NOTE(review): ``__getitem__`` calls ``.unsqueeze`` downstream
                in the collate step, so this presumably must end in a
                to-tensor transform — confirm.
            freq_threshold: minimum word frequency forwarded to ``Vocabulary``.
            size: target size handed to ``PIL.Image.resize``.
        """
        self.path = path
        self.dataframe = cvs_file
        self.size = size
        self.transform = transform

        self.img_col = self.dataframe["Imgs_paths"]
        self.findings_col = self.dataframe["findings"]

        self.vocab = Vocabulary(freq_threshold)
        # Join the reports with a space: an empty separator would glue the last
        # word of one report to the first word of the next one.
        self.st = " "
        self.vocab.build_vocabulary(self.st.join(self.findings_col.tolist()))

    def __len__(self):
        """Return the number of rows in the DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Load the ``index``-th sample (positional, 0-based).

        Returns:
            ``(img, caption)`` — the (optionally transformed) image and a
            tensor of token indices ``[<SOS>, ..., <EOS>]``.
        """
        # Positional lookup: the DataLoader hands out 0..len-1, which only
        # matches label-based indexing when the frame has a default index.
        findings = self.findings_col.iloc[index]
        img_id = self.img_col.iloc[index]
        img_path = f"{self.path}/Images/{img_id}.png"

        # resize() returns a new in-memory image, so the file can be closed
        # as soon as the context manager exits.
        with Image.open(img_path) as raw_img:
            img = raw_img.resize(self.size)

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.string_to_index["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(findings)
        numericalized_caption.append(self.vocab.string_to_index["<EOS>"])

        return img, torch.tensor(numericalized_caption)

In [15]:
def get_loader(csv_file, path, transform, batch_size, shuffle=True, freq_threshold=5):
    """Build a ``DataLoader`` over an ``XRayDataset``.

    Args:
        csv_file: DataFrame with the image ids and findings text.
        path: root folder of the data, forwarded to ``XRayDataset``.
        transform: image transform forwarded to ``XRayDataset``.
        batch_size: number of samples per batch.
        shuffle: whether the loader reshuffles the data every epoch.
        freq_threshold: minimum word frequency for the vocabulary, forwarded
            to ``XRayDataset`` (which requires it).

    Returns:
        The ``DataLoader``. The dataset and its vocabulary remain reachable
        through ``loader.dataset`` / ``loader.dataset.vocab``.
    """
    dataset = XRayDataset(csv_file, path, transform, freq_threshold)

    # Captions are padded with the vocabulary's <PAD> index.
    pad_idx = dataset.vocab.string_to_index["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )
    return loader

In [16]:
class EncoderCNN(nn.Module):
    """Frozen Inception-v3 backbone that turns images into a sequence of feature vectors."""

    def __init__(self):
        super().__init__()

        # NOTE(review): combining pretrained=True with aux_logits=False may be
        # rejected by newer torchvision releases — confirm against the
        # installed version.
        backbone = models.inception_v3(pretrained=True, aux_logits=False)

        # The backbone is used as a fixed feature extractor: no gradients.
        for weight in backbone.parameters():
            weight.requires_grad_(False)

        # Keep every child module except the last four.
        kept_layers = list(backbone.children())[:-4]
        self.inception = nn.Sequential(*kept_layers)

    def forward(self, images):
        """Encode ``images`` into per-location feature vectors.

        Returns:
            Tensor of shape ``(batch, height * width, channels)`` where
            ``height``/``width``/``channels`` are those of the backbone's
            output feature map (presumably 2048 channels, per the original
            note — confirm).
        """
        feature_maps = self.inception(images)
        n_batch, n_channels, height, width = feature_maps.size()

        # Move channels last, then merge the two spatial axes into one.
        channels_last = feature_maps.permute(0, 2, 3, 1)
        return channels_last.view(n_batch, height * width, n_channels)

In [17]:
class BahdanauAttention(nn.Module):
    """Additive attention over encoder features, conditioned on the decoder state."""

    def __init__(self, num_features, hidden_dim, output_dim=1):
        """
        Args:
            num_features: size of each encoder feature vector.
            hidden_dim: size of the decoder hidden state and of the internal
                attention projection.
            output_dim: size of the score produced per location; the default
                of 1 gives one scalar weight per location.
        """
        super(BahdanauAttention, self).__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Projects the encoder features into the attention space.
        self.W_a = nn.Linear(self.num_features, self.hidden_dim)
        # Projects the previous decoder hidden state into the attention space.
        self.U_a = nn.Linear(self.hidden_dim, self.hidden_dim)
        # Maps the combined projection to a score per location.
        self.V_a = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, features, decoder_hidden):
        """Compute the context vector and the attention weights.

        Args:
            features: encoder output; assumed ``(batch, num_locations,
                num_features)`` — TODO confirm against the caller.
            decoder_hidden: previous decoder hidden state; assumed
                ``(batch, hidden_dim)`` — TODO confirm against the caller.

        Returns:
            ``(context, atten_weight)``: the features summed over dim 1 after
            weighting, and the weights with dim 2 squeezed out.
        """
        # Insert a singleton axis at dim 1 so the state broadcasts over all
        # locations when added to the projected features.
        decoder_hidden = decoder_hidden.unsqueeze(1)
        Wa = self.W_a(features)
        Ua = self.U_a(decoder_hidden)
        atten_tan = torch.tanh(Wa + Ua)
        atten_score = self.V_a(atten_tan)
        # Normalise over dim 1 so the weights of one sample sum to 1.
        atten_weight = F.softmax(atten_score, dim=1)

        # Weighted sum of the features over dim 1.
        context = torch.sum(atten_weight * features, dim=1)
        atten_weight = atten_weight.squeeze(dim=2)

        return context, atten_weight