In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [1]:
import torch
import pandas as pd
import os
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import torch.optim as optim
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification


def test(model, test_loader, criterion, device, n_way):
    # An F1 Score of 0 indicates that it is invalid
    model.eval()
    true_positive = list(0. for i in range(n_way))  # Number of correctly predicted samples per class
    total_truth = list(0. for i in range(n_way))  # Number of ground truths per class
    predicted_positive = list(0. for i in range(n_way))  # Number of predicted samples per class
    precision = list(0. for i in range(n_way))
    recall = list(0. for i in range(n_way))
    class_f1 = list(0. for i in range(n_way))
    val_loss = 0
    correct_total = 0  # Total correctly predicted samples
    total = 0  # Total samples
    f1_flag = 0  # Flag for invalid F1 score
    with torch.no_grad():
        for step, (data_inputs, data_labels) in enumerate(test_loader):
            inputs, labels = data_inputs.to(device), data_labels.to(device)
            pred = model(inputs)
            loss = criterion(pred, labels)
            val_loss += loss.item()  # Running validation loss
            _, predicted = torch.max(pred, 1)
            correct = (predicted == labels).squeeze()  # Samples that are correctly predicted
            correct_total += (predicted == labels).sum().item()
            total += labels.size(0)

            for i in range(len(predicted)):
                label = labels[i]
                true_positive[label] += correct[i].item()
                total_truth[label] += 1
                predicted_positive[predicted[i].item()] += 1  # True Positive + False Positive

        # Find class accuracy, precision and recall
        for j in range(n_way):
            if (predicted_positive[j] != 0 and true_positive[j] != 0):  # Check if F1 score is valid
                precision[j] = true_positive[j] / predicted_positive[j]
                recall[j] = true_positive[j] / total_truth[j]  # Recall is the same as per class accuracy
                class_f1[j] = 2 * precision[j] * recall[j] / (precision[j] + recall[j])
            else:
                f1_flag = 1

        # Find Accuracy, Macro Accuracy and Macro F1 Score
        macro_acc_sum = 0
        f1_sum = 0
        for k in range(n_way):
            macro_acc_sum += recall[k]
            if f1_flag == 0:  # Check for invalid f1 score
                f1_sum += class_f1[k]

        accuracy = correct_total / total
        macro_accuracy = macro_acc_sum / n_way
        f1_score = f1_sum / n_way

    return val_loss / (step+1), accuracy, macro_accuracy, f1_score, class_f1

class MimicCxrReports(Dataset):
    """
    MIMIC-CXR Database, Reports Only
    Todo: Insert references to the database here!
    """
    
    def __init__(self, root, csv_path, tokenizer, mode):
        
        # Check if mode contains an accepted value
        if mode not in ('base_train', 'base_validate', 'novel_train', 'novel_validate'):
            raise Exception("Selected 'mode' is not valid")
            
        self.root = root
        csv_data = pd.read_csv(csv_path)
        csv_data = csv_data[csv_data.split == mode]
        
        if mode == 'base_train' or mode == 'base_validate':
            dict_labels = {
                'Atelectasis': 0,
                'Cardiomegaly': 1,
                'Consolidation': 2,
                'Edema': 3,
                'Fracture': 4,
                'Lung Opacity': 5,
                'No Finding': 6,
                'Pneumonia': 7,
                'Pneumothorax': 8,
                'Support Devices': 9
            }
        else:
            dict_labels = {
                'Enlarged Cardiomediastinum': 0,
                'Lung Lesion': 1,
                'Pleural Effusion': 2,
            }
            
            # Get text encodings and labels
            texts = []
            labels = []
            for index, row in csv_data.iterrows():
                text_name = f'{row["file_path"].split("/")[2]}.txt' # Only the study id is required to find the report
                text_path = Path(os.path.join(self.root, text_name))
                texts.append(text_path.read_text())
                labels.append(dict_labels[row['labels']])
            self.labels = labels
            self.encodings = tokenizer(texts, truncation=True, padding=True)
        
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

root='../../../../scratch/rl80/mimic-cxr-2.0.0.physionet.org'
csv_path = '../splits/20_shot.csv'
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

train_dataset = MimicCxrReports(root, csv_path, tokenizer, mode='novel_train')
test_dataset = MimicCxrReports(root, csv_path, tokenizer, mode='novel_validate')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=3)
model.to(device)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)
#optimizer = AdamW(model.parameters(), lr=5e-5)
optimizer = optim.Adam(model.parameters(), lr=1e-5)

for epoch in range(3):
    model.train()
    train_loss = 0
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        loss.backward()
        optimizer.step()
        train_loss += loss.item()  # Running training loss
#model.eval()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

In [2]:
outputs.loss

AttributeError: 'tuple' object has no attribute 'loss'

In [3]:
texts = []
labels = []
for index, row in csv_data.iterrows():
    text_name = f'{row["file_path"].split("/")[2]}.txt' # Only the study id is required to find the report
    text_path = Path(os.path.join(root, text_name))
    texts.append(text_path.read_text())
    #labels.append(dict_labels[row['labels']])
print(len)

                                 FINAL REPORT
 HISTORY:  Evaluate for the pneumothorax, pigtail catheter connected to
 Pleurovac now with leak.
 
 COMPARISON:  ___ at 7:14.
 
 FINDINGS:
 
 The left pigtail catheter, right chest port and AICD leads are in unchanged
 position.  A lucency along the left mediastinum could represent medial
 pneumothorax, not significantly changed from earlier exam.  Otherwise, no
 significant change in bilateral pleural effusions.  No focal consolidation is
 present.  No evidence of pulmonary vascular congestion.
 
 IMPRESSION:
 
 Lucency along the left mediastinum could represent medial pneumothorax, not
 significantly changed from earlier radiograph.  Otherwise, no significant
 change from prior radiographs.
 
 NOTIFICATION:  Findings discussed with Dr. ___ by Dr. ___ at 14:30
 on ___.



In [6]:
labels.size()

torch.Size([12])