In [24]:
import os
import xml.etree.ElementTree as ET
from collections import Counter
import pandas as pd
from PIL import Image
import numpy as np
from google.colab import output

from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch import nn

import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
import torchvision.models as models
import torch.optim as optim

random_state = 42

## Dataframe creation

In [None]:
# Data exploration
path = ''
imgs_path = path + '/Images'
reports_path = path + '/reports'
imgs_list = os.listdir(imgs_path)
rep_list = os.listdir(reports_path)

print("Number of images: ", len(imgs_list))
print("Number of reports: ", len(rep_list))

n_imgs_per_report = []
for rep in rep_list:
    tree = ET.parse(reports_path + "/" + rep)
    root = tree.getroot()
    images = root.findall(".//parentImage")
    n_images = len(images)
    n_imgs_per_report.append(n_images)
    
print("Number of images per report: ", Counter(n_imgs_per_report))

In [None]:
indications = []
findings = []
imgs_paths = []

for rep in rep_list:
    root = ET.parse(reports_path + "/" + rep).getroot()
    imgs = root.findall(".//parentImage")
    n_images = len(imgs)
    if n_images == 0:
        continue
    else:
        description = root.findall(".//AbstractText")
        indication = description[1].text
        finding = ""
        for d in description[2:]:
            try:
                finding += ". " + d.text
            except:
                pass
        indications.append(indication)
        indications.append(indication)
        findings.append(finding)
        findings.append(finding)

        if n_images >= 2:
            imgs_paths.append(imgs[0].attrib['id'])
            imgs_paths.append(imgs[1].attrib['id'])
        if n_images == 1:
            imgs_paths.append(imgs[0].attrib['id'])
            imgs_paths.append(imgs[0].attrib['id'])

## Dataset loader

In [9]:
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    def __init__(self, freq_thresh):
        self.index_to_string = {0: "<PAD>", 1: "<SOS>", 2:"<EOS>", 3:"<UNK>"}
        self.string_to_index = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_thresh = freq_thresh
        
    def __len__(self):
        return len(self.index_to_string)
    
    @staticmethod
    def tokenizer(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def return_index_to_string(self):
        return self.index_to_string
    
    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4
        
        """for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1
                
                if frequencies[word] == self.freq_thresh:
                    self.string_to_index[word] = idx
                    self.index_to_string[idx] = word
                    idx += 1"""
        for word in self.tokenizer(sentence_list):
            if word not in frequencies:
                frequencies[word] = 1
            else:
                frequencies[word] += 1
                
            if frequencies[word] == self.freq_thresh:
                self.string_to_index[word] = idx
                self.index_to_string[idx] = word
                idx += 1
                    
    def numericalize(self, text):
        tokenized_text = self.tokenizer(text)
        
        return [self.string_to_index[token] if token in self.string_to_index else self.string_to_index["<UNK>"] 
                for token in tokenized_text]
    
class MyCollate:
        def __init__(self, pad_idx):
            self.pad_idx = pad_idx
            
        def __call__(self, batch):
            imgs = [item[0].unsqueeze(0) for item in batch]
            imgs = torch.cat(imgs, dim=0)
            targets = [item[1] for item in batch]
            targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
            
            return imgs, targets

In [31]:
class XRayDataset(Dataset):
    def __init__(self, csv_file, path,transform, freq_thresh=3, size=(624,512)):
        #path is for general folder, csv_file is this file
        self.path = path
        self.dataframe = csv_file
        self.size = size
        self.transform = transform
        
        self.img_col = self.dataframe["Imgs_paths"]
        self.findings_col = self.dataframe["findings"]
        
        self.vocab = Vocabulary(freq_thresh)
        self.st = ""
        self.vocab.build_vocabulary(self.st.join(self.findings_col.tolist()[:]))
        
    def __len__(self):
        return len(self.dataframe) #7702
    
    def __getitem__(self, index):
        finding = self.findings_col[index]
        img_id = self.img_col[index]
        img_path = self.path + "/Images/" + img_id + ".png"
        #img = Image.open(img_path).convert('L').resize(self.size) this will be used when I addapt model to grayscale imgs
        img = Image.open(img_path).resize(self.size)
        
        if self.transform is not None:
            img = self.transform(img)
        
        numericalized_caption = [self.vocab.string_to_index["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(finding)
        numericalized_caption.append(self.vocab.string_to_index["<EOS>"])
        
        return img, torch.tensor(numericalized_caption)


In [32]:
def get_loader(csv_file, path, transform, batch_size=32, shuffle=True):
    dataset = XRayDataset(csv_file, path, transform)
    
    pad_idx = dataset.vocab.string_to_index["<PAD>"]
    
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=MyCollate(pad_idx=pad_idx)
    )
    
    return loader, dataset

transform = transforms.Compose(
[
    transforms.ToTensor()
])

## Network

In [33]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        self.train_CNN = train_CNN
        self.inception = models.inception_v3(pretrained=True, aux_logits=False) #zmienic ten model pozniej
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, images):
        features = self.inception(images)
        
        for name, param in self.inception.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = self.train_CNN
                
        return self.dropout(self.relu(features))
    
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions))
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs
    
class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
        
    def forward(self, images, captions):
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features, captions)
        return outputs
    
    def caption_image(self, image, vocabulary, max_length=50): #here max length shoudl be changed
        result_caption = []
        
        with torch.no_grad():
            X = self.encoderCNN(image).unsqueeze(0)
            states = None
            
            for _ in range(max_length):
                hiddens, states = self.decoderRNN.lstm(x, states)
                output = self.decoderRNN.linear(hiddens.squeeze(0))
                predicted = output.argmax(1)
                
                result_caption.append(predicted.item())
                x = self.decoderRNN.embed(predicted).unsqueeze(0)
                
                if vocabulary.index_to_string[predicted.item()] == "<EOS>":
                    break
        return [vocabulary.index_to_string[idx] for idx in result_caption]

In [64]:
def save_checkpoint(state, file_path):
    print("Saving")
    torch.save(state, file_path)

def train(embed_size, hidden_size, vocab_size, num_layers, learning_rate, num_epochs, model, criterion, optimizer,train_loader,valid_loader, device, file_path):
    transform = transforms.Compose(
    [
        transforms.Resize((299,299)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    loss_history = []
    valid_loss_history = []

    load_model = False
    save_model = True

    model.train()
    step = 0
    
    for epoch in range(num_epochs):
        checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
        model.train()

        # train set
        epoch_train_loss = []
        for batch_idx, (imgs, captions) in enumerate(train_loader):
            imgs = imgs.to(device)
            captions = captions.to(device)
            
            outputs = model(imgs, captions[:-1])
            loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
            
            step += 1
            
            optimizer.zero_grad()
            loss.backward(loss)
            epoch_train_loss.append(loss)
            optimizer.step()
            
            output.clear()
            print("Example {}/{}. Epoch {}/{}. Loss: {}".format(batch_idx, len(train_loader), epoch+1, num_epochs, loss))
            #if batch_idx >= 2:
            #    break
        if save_model:
              save_checkpoint(checkpoint, file_path)
        loss_history.append(np.sum(epoch_train_loss)/(batch_idx+1)) #tutaj sie chyba dodaje tylko ostatnia strata a nie srednia z epoki

        # validation set
        model.eval()
        epoch_valid_loss = []
        for batch_idx, (imgs, captions) in enumerate(valid_loader):
            imgs = imgs.to(device)
            captions = captions.to(device)

            outputs = model(imgs, captions[:-1])
            loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
            epoch_valid_loss.append(loss)
            step += 1

            output.clear()
            print("Valid example {}/{}. Epoch {}/{}. Loss: {}".format(batch_idx, len(train_loader), epoch+1, num_epochs, loss))
            #if batch_idx >= 2:
            #    break
        valid_loss_history.append(np.sum(epoch_valid_loss)/(batch_idx+1))

    return loss_history, valid_loss_history

In [65]:
path = '/content/drive/MyDrive/Data/XrayNLP'
csv_path = path + "/" + "dataframe.csv"
checkpoint_path = path + "/" + "checkpoint.pth.tar"

df = pd.read_csv(csv_path)
csv_train, csv_validate, csv_test = np.split(df.sample(frac=1, random_state=random_state), [int(.8*len(df)), int(.9*len(df))])
csv_train, csv_validate, csv_test = csv_train.reset_index(), csv_validate.reset_index(), csv_test.reset_index()

_, dataset = get_loader(df, path, transform)
train_loader, train_dataset = get_loader(csv_train, path, transform)
valid_loader, valid_dataset = get_loader(csv_validate, path, transform)
test_loader, test_dataset = get_loader(csv_test, path, transform)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embed_size = 256
hidden_size = 256
vocab_size = len(train_dataset.vocab)
num_layers = 1
learning_rate = 3e-4
num_epochs=5
    
step = 0
    
model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.vocab.string_to_index["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

loss_history, valid_loss_history = train(embed_size, hidden_size, vocab_size, num_layers, learning_rate, num_epochs, model, criterion, optimizer,train_loader,valid_loader, device, checkpoint_path)

Valid example 23/189. Epoch 5/5. Loss: 6.08713960647583


In [66]:
loss_history

[tensor(4.6280, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(3.2495, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(2.8621, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(2.6376, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(2.4853, device='cuda:0', grad_fn=<DivBackward0>)]

In [67]:
valid_loss_history

[tensor(5.1149, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(5.4642, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(5.7283, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(5.9285, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(6.0927, device='cuda:0', grad_fn=<DivBackward0>)]

Loss at the beginning: 7.2701191902160645

In [68]:
model.eval()
outputs_list = []

with torch.no_grad():
      for batch_idx, (imgs, captions) in enumerate(test_loader):
            imgs = imgs.to(device)
            captions = captions.to(device)
            
            outputs = model(imgs, captions[:-1])
            outputs_list.append(outputs)            
            #writer.add_scalar("Training loss", loss.item(), global_step=step)
            step += 1
            if batch_idx >= 1:
                break

In [180]:
np.shape(outputs_list[0][:,0,:])

torch.Size([102, 1478])

In [69]:
index_dict = Vocabulary.return_index_to_string(dataset.vocab)

result = ""
for value in np.argmax(outputs_list[1][:,20,:].tolist(), axis=1):
    result += " " + index_dict[value]

result

' <SOS> . xxxx silhouette of soft size xxxx airspace pneumothorax limits no . normal . of normal . . is 2 there pneumothorax airspace 2 . is focal right . evidence disease density . xxxx opacity size size size size size are . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .'

In [10]:
path = '/content/drive/MyDrive/Data/XrayNLP'
csv_path = path + "/" + "dataframe.csv"
checkpoint_path = path + "/" + "checkpoint.pth.tar"

train_loader, dataset = get_loader(csv_path, path, transform)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embed_size = 256
hidden_size = 256
vocab_size = len(dataset.vocab)
num_layers = 1
learning_rate = 3e-4
num_epochs=2
    
step = 0

model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.string_to_index["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [14]:
def load_checkpoint(checkpoint_path):
    print("Loading checkpoint")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

In [15]:
load_checkpoint(checkpoint_path)

Loading checkpoint
