# Imports

In [16]:
import copy
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')

import pandas as pd
import re
import time
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from fastai.vision import transform
from PIL import Image, ImageFilter, ImageEnhance
from scipy import stats
from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.metrics import confusion_matrix
from torch.utils.data import Dataset, DataLoader, Sampler
 
from torchvision import transforms, models
from tqdm import tqdm

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [3]:
cpu = torch.device('cpu')

# Looking at Data

In [4]:
!ls MURA-v1.1/

train		       train_labeled_studies.csv  valid_image_paths.csv
train_image_paths.csv  valid			  valid_labeled_studies.csv


In [5]:
train_labled_studies = pd.read_csv('MURA-v1.1/valid_image_paths.csv', header=None)
train_labled_studies.head()

Unnamed: 0,0
0,MURA-v1.1/valid/XR_WRIST/patient11185/study1_p...
1,MURA-v1.1/valid/XR_WRIST/patient11185/study1_p...
2,MURA-v1.1/valid/XR_WRIST/patient11185/study1_p...
3,MURA-v1.1/valid/XR_WRIST/patient11185/study1_p...
4,MURA-v1.1/valid/XR_WRIST/patient11186/study1_p...


In [6]:
train_labled_studies.loc[0].iloc[0]

'MURA-v1.1/valid/XR_WRIST/patient11185/study1_positive/image1.png'

In [7]:
del train_labled_studies

In [8]:
class MuraDatasetByStudy(Dataset):
    def __init__(self, csv_file, transform):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable): Transform to be applied on a sample.
        """
        
        df = pd.read_csv(csv_file, header=None)
        self.study_pths = df.iloc[:,0]
        self.labels = df.iloc[:, 1]
        self.transform = transform

    def __len__(self):
        return len(self.study_pths)
    
    def __getitem__(self, idx):
        study_path = self.study_pths[idx]
        
        # get images for this study
        i = 0
        images = []
        while(True):
            img_name = study_path + 'image%s.png' % (i+1)
            try:
                img = Image.open(img_name)
                img = img.convert('RGB')
                images.append(self.transform(img))
                i += 1
            except FileNotFoundError:
                break
        
        images = torch.stack(images)
        label = self.labels[idx]
        
        m = re.match(r'.*/XR_(\w*)/patient(\d+)/study(\d+)_', img_name)
        study_type = m.group(1)
        study_id = m.group(2)+'/'+m.group(3)
        
        return images, study_type, study_id, label
    
    def __iter__(self):
        self.idx = 0
        return self
    
    def __next__(self):
        if self.idx >= len(self):
            raise StopIteration
        
        result = self[self.idx]
        self.idx += 1
        return result

In [9]:
class MuraDataset(Dataset):
    def __init__(self, csv_file, transform, augment_transforms=[]):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        traindf = pd.read_csv(csv_file, header=None)
        self.train_img_pths = traindf.iloc[:,0]
        self.transform = transform
        self.augment_transforms = augment_transforms

    def __len__(self):
        return (len(self.augment_transforms)+1)*len(self.train_img_pths)
    
    def __getitem__(self, idx):
        if idx >= len(self.train_img_pths):
            idx = idx%len(self.train_img_pths)
            transform = self.augment_transforms[idx//len(self.train_img_pths)-0]
        else:
            transform = self.transform

        img_name = self.train_img_pths[idx]
        img = Image.open(img_name)
        img = img.convert('RGB')
        img = transform(img)
        
        img_study_type = re.match(r'.*/XR_(\w*)/', img_name).group(1)
        img_label = int(
            re.match(r'.*/study\d+_(\w+)/', img_name).group(1) == 'positive'
        )
        img_study_name = re.match(r'(.*/).*png', img_name).group(1)
        
        return img, img_study_name, img_study_type, img_label
    
    def __iter__(self):
        self.idx = 0
        return self
    
    def __next__(self):
        if self.idx >= len(self):
            raise StopIteration
        
        result = self[self.idx]
        self.idx += 1
        return result

# Model

In [10]:
batch_size = 8

# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

model = models.resnet18(pretrained=False)
# freeze everything to extract feature vector
# for param in model.parameters():
#     param.requires_grad = False

# replace linear layer
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

optimizer = optim.Adam(model.parameters(), lr=0.2)

In [11]:
input_size = (224, 224)
data_transforms = {
        'train': transforms.Compose([
                transforms.RandomAffine(
                    degrees=30,
                    translate=(0, 0.2),
                    scale=(1, 1.5),
                    shear=10,
                ),
                transforms.Resize(input_size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'valid': transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }


datasets_dict = {
    x: MuraDataset(
        csv_file='MURA-v1.1/%s_image_paths.csv' % x,
        transform=data_transforms[x],
        augment_transforms=[],
    ) for x in ['train', 'valid']
}

# # subset
# dataloaders_dict = {
#     x: DataLoader(
#         torch.utils.data.Subset(
#             dataset=datasets_dict[x],
#             indices=np.random.randint(
#                 len(datasets_dict[x]),
#                 size=100,
#             ),
#         ),
#         batch_size=batch_size,
#         num_workers=8,
#     ) for x in ['train', 'valid']
# }

dataloaders_dict = {
    x: DataLoader(
        datasets_dict[x],
        shuffle=True,
        batch_size=batch_size,
        num_workers=8,
    ) for x in ['train', 'valid']
}

In [12]:
def update_learning_rate(optimizer, update):
    for param_group in optimizer.param_groups:
        param_group['lr'] = update(param_group['lr'])

In [22]:
epoch_n = 7
M = 5
T = len(dataloaders_dict['train'].dataset)/batch_size*epoch_n*M

def shifted_cosine_function(t, lr):
    return lr/2*(np.cos(np.pi*((t-1)%np.ceil(T/M))/np.ceil(T/M))+1)

In [23]:
def train_model(model, train_loader, criterion, optimizer, epoch_n, validation_loader):
    model = model.to(device)
    last_valid_acc = 0
    
    for epoch in range(epoch_n):
        model.train()
        print('=========================epoch %i=====================' % epoch)
        epoch_loss = 0.0
        running_loss = 0.0
        epoch_corrects = 0
        for i, data in enumerate(tqdm(train_loader)):
            inputs, _, study_type, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            loss.sum().backward()
            optimizer.step()
            _, preds = torch.max(outputs, 1)
            epoch_corrects += torch.sum(preds == labels.data)

            # print statistics
            running_loss += loss.sum()
            epoch_loss += loss.sum()

        epoch_loss = epoch_loss / len(train_loader.dataset)
        epoch_accuracy = epoch_corrects.double() / len(train_loader.dataset)
        print('epoch loss: %.8f' % epoch_loss)
        print('epoch accuracy: %.8f' % epoch_accuracy)

        # else evaluate first then keep training
        _, valid_loss, valid_acc, valid_kappa = validate_model(
            [model],
            validation_loader,
            criterion,
            optimizer,
        )
        
        print('epoch valid_loss: %8f' % valid_loss)
        print('epoch valid_acc: %.8f' % valid_acc)
        print('last epoch valid_acc: %.8f' % last_valid_acc)
        print('epoch valid_kappa: %.8f' % valid_kappa)

        if abs(valid_acc - last_valid_acc) < 0.0001:
            update_learning_rate(optimizer, update= lambda x: 0.1*x)
        last_valid_acc = valid_acc
    return model

In [24]:
def train_model_snapshot_ensembling(model, train_loader, criterion, optimizer, epoch_n, cycles, validation_loader):
    model = model.to(device)
    t = 0 # Total number of iterations
    models = []
    for cycle in range(cycles):
        for epoch in range(epoch_n):
            model.train()
            print('=========================epoch %i=====================' % epoch)
            epoch_loss = 0.0
            running_loss = 0.0
            epoch_corrects = 0
            for i, data in enumerate(tqdm(train_loader)):
                # update learning rate
                t += 1
                update_learning_rate(optimizer, lambda lr: shifted_cosine_function(t, 0.2))

                inputs, _, study_type, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                loss.sum().backward()
                optimizer.step()
                _, preds = torch.max(outputs, 1)
                epoch_corrects += torch.sum(preds == labels.data)

                # print statistics
                running_loss += loss.sum()
                epoch_loss += loss.sum()

            epoch_loss = epoch_loss / len(train_loader.dataset)
            epoch_accuracy = epoch_corrects.double() / len(train_loader.dataset)
            print('epoch loss: %.8f' % epoch_loss)
            print('epoch accuracy: %.8f' % epoch_accuracy)

            # else evaluate first then keep training
            _, valid_loss, valid_acc, valid_kappa = validate_model(
                [model],
                validation_loader,
                criterion,
                optimizer,
            )
            print('epoch valid_loss: %8f' % valid_loss)
            print('epoch valid_acc: %.8f' % valid_acc)
            print('epoch valid_kappa: %.8f' % valid_kappa)
        
        print('cycle %i is over.' % cycle)
        print('saving model..\n')
        models.append(model.state_dict())
    return models 

In [19]:
def validate_model(models, validation_loader, criterion, optimizer):
    """
    returns (df, acc, kappa)
        df: Pandas dataframe containing study paths and their corresponding label
        obtained by averaging the model scores on the study's images and treating any
        score higher than 0.5 as abnormal.
        
        acc: model accuracy score on studies
        kappa: model kappa score on studies
    """
    for model in models:
        model = model.to(device)
        model.eval()
    total_loss = 0.0
    
    all_study_names = []
    all_outputs = []
    
    for i, data in enumerate(tqdm(validation_loader)):
        inputs, study_names, study_types, labels = data
        all_study_names += study_names
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        
        with torch.set_grad_enabled(False):
            ensembled_outputs = []
            for model in models:
                outputs = model(inputs)
                ensembled_outputs.append(outputs)
            outputs = torch.stack(ensembled_outputs)
            outputs = torch.mean(outputs, 0)
            all_outputs += outputs.tolist()
            loss = criterion(outputs, labels)
        
        total_loss += loss.sum()
    
    total_loss /= len(validation_loader.dataset)
    results_df = pd.DataFrame({
        'study': all_study_names,
        'prediction': [np.argmax(o) for o in all_outputs]
    })
    
    predicted_labeled_studies = results_df.groupby('study').agg({'prediction': lambda x: stats.mode(x, axis=None)[0][0]})
    predicted_labeled_studies = predicted_labeled_studies.reset_index()

    valid_labeled_studies = pd.read_csv(
        'MURA-v1.1/valid_labeled_studies.csv',
        header=None,
        names= ['study', 'label']
    )
    
   
    joined_df = predicted_labeled_studies.join(valid_labeled_studies.set_index('study'), on='study')
    
    
    accuracy = accuracy_score(joined_df.label, joined_df.prediction)
    kappa = cohen_kappa_score(joined_df.label, joined_df.prediction)
    from sklearn.metrics import confusion_matrix
    confusion_m = confusion_matrix(joined_df.label, joined_df.prediction)
    
    return joined_df, total_loss.item(), accuracy, kappa, confusion_m

In [None]:
model = train_model_snapshot_ensembling(model, dataloaders_dict['train'], criterion, optimizer, epoch_n, M, dataloaders_dict['valid'])

  0%|          | 0/4601 [00:00<?, ?it/s]



100%|██████████| 4601/4601 [03:31<00:00, 21.72it/s]
  0%|          | 0/400 [00:00<?, ?it/s]

epoch loss: 0.09168026
epoch accuracy: 0.56732232


100%|██████████| 400/400 [00:10<00:00, 38.37it/s]
  0%|          | 0/4601 [00:00<?, ?it/s]

epoch valid_loss: 0.086643
epoch valid_acc: 0.55129274
epoch valid_kappa: 0.00000000


100%|██████████| 4601/4601 [02:20<00:00, 39.57it/s]
  0%|          | 0/400 [00:00<?, ?it/s]

epoch loss: 0.08751455
epoch accuracy: 0.56797435


100%|██████████| 400/400 [00:08<00:00, 47.64it/s]
  0%|          | 0/4601 [00:00<?, ?it/s]

epoch valid_loss: 0.087199
epoch valid_acc: 0.55129274
epoch valid_kappa: 0.00000000


100%|██████████| 4601/4601 [02:22<00:00, 32.36it/s]
  0%|          | 0/400 [00:00<?, ?it/s]

epoch loss: 0.08585342
epoch accuracy: 0.57737448


100%|██████████| 400/400 [00:08<00:00, 47.48it/s]
  0%|          | 0/4601 [00:00<?, ?it/s]

epoch valid_loss: 0.087403
epoch valid_acc: 0.55129274
epoch valid_kappa: 0.00000000


100%|██████████| 4601/4601 [02:21<00:00, 32.50it/s]
  0%|          | 0/400 [00:00<?, ?it/s]

epoch loss: 0.08535801
epoch accuracy: 0.58506303


100%|██████████| 400/400 [00:08<00:00, 47.93it/s]
  0%|          | 0/4601 [00:00<?, ?it/s]

epoch valid_loss: 0.086623
epoch valid_acc: 0.55129274
epoch valid_kappa: 0.00000000


100%|██████████| 4601/4601 [02:21<00:00, 32.48it/s]
  0%|          | 0/400 [00:00<?, ?it/s]

epoch loss: 0.08498706
epoch accuracy: 0.58832319


100%|██████████| 400/400 [00:08<00:00, 47.83it/s]
  0%|          | 0/4601 [00:00<?, ?it/s]

epoch valid_loss: 0.086676
epoch valid_acc: 0.55129274
epoch valid_kappa: 0.00000000


100%|██████████| 4601/4601 [02:20<00:00, 32.65it/s]
  0%|          | 0/400 [00:00<?, ?it/s]

epoch loss: 0.08456496
epoch accuracy: 0.59506086


100%|██████████| 400/400 [00:08<00:00, 47.53it/s]
  0%|          | 0/4601 [00:00<?, ?it/s]

epoch valid_loss: 0.087564
epoch valid_acc: 0.55129274
epoch valid_kappa: 0.00000000


 57%|█████▋    | 2612/4601 [01:20<00:59, 33.28it/s]

In [None]:
def get_model():
    m = models.resnet18(pretrained=False)
    num_ftrs = m.fc.in_features
    m.fc = nn.Linear(num_ftrs, 2)
    return m

m1 = get_model().load_state_dict(model[0])

In [None]:
best_models = [None, None, None, None, None]
for i in range(5):
    best_models[i] = get_model()
    best_models[i].load_state_dict(model[i])

In [14]:
best_models = [torch.load('resnet_model_crop')]

In [20]:
_, valid_loss, valid_acc, valid_kappa, confusion_matrix = validate_model(
        best_models,
        dataloaders_dict['valid'],
        criterion,
        optimizer,
    )

print('epoch valid_loss: %8f' % valid_loss)
print('epoch valid_acc: %.8f' % valid_acc)
print('epoch valid_kappa: %.8f' % valid_kappa)
print(confusion_matrix)

100%|██████████| 400/400 [00:08<00:00, 49.03it/s]


epoch valid_loss: 0.069042
epoch valid_acc: 0.74311927
epoch valid_kappa: 0.45926618
[[622  39]
 [269 269]]
