# Imports

In [1]:
import copy
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')

import pandas as pd
import re
import time
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from inception import inception_v3
from PIL import Image
from sklearn.metrics import accuracy_score, cohen_kappa_score
from torch.utils.data import Dataset, DataLoader, Sampler
 
from torchvision import transforms, models
from tqdm import tqdm

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [3]:
cpu = torch.device('cpu')

# Looking at Data

In [4]:
!ls MURA-v1.1/

train		       train_labeled_studies.csv  valid_image_paths.csv
train_image_paths.csv  valid			  valid_labeled_studies.csv


In [5]:
train_labled_studies = pd.read_csv('MURA-v1.1/valid_image_paths.csv', header=None)
train_labled_studies.head()

Unnamed: 0,0
0,MURA-v1.1/valid/XR_WRIST/patient11185/study1_p...
1,MURA-v1.1/valid/XR_WRIST/patient11185/study1_p...
2,MURA-v1.1/valid/XR_WRIST/patient11185/study1_p...
3,MURA-v1.1/valid/XR_WRIST/patient11185/study1_p...
4,MURA-v1.1/valid/XR_WRIST/patient11186/study1_p...


In [6]:
train_labled_studies.loc[0].iloc[0]

'MURA-v1.1/valid/XR_WRIST/patient11185/study1_positive/image1.png'

In [7]:
class MuraDatasetByStudy(Dataset):
    def __init__(self, csv_file, transform):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable): Transform to be applied on a sample.
        """
        
        df = pd.read_csv(csv_file, header=None)
        self.study_pths = df.iloc[:,0]
        self.labels = df.iloc[:, 1]
        self.transform = transform

    def __len__(self):
        return len(self.study_pths)
    
    def __getitem__(self, idx):
        study_path = self.study_pths[idx]
        
        # get images for this study
        i = 0
        images = []
        while(True):
            img_name = study_path + 'image%s.png' % (i+1)
            try:
                img = Image.open(img_name)
                img = img.convert('RGB')
                images.append(self.transform(img))
                i += 1
            except FileNotFoundError:
                break
        
        images = torch.stack(images)
        label = self.labels[idx]
        
        m = re.match(r'.*/XR_(\w*)/patient(\d+)/study(\d+)_', img_name)
        study_type = m.group(1)
        study_id = m.group(2)+'/'+m.group(3)
        
        return images, study_type, study_id, label
    
    def __iter__(self):
        self.idx = 0
        return self
    
    def __next__(self):
        if self.idx >= len(self):
            raise StopIteration
        
        result = self[self.idx]
        self.idx += 1
        return result

In [8]:
class MuraDataset(Dataset):
    def __init__(self, csv_file, transform, augment_transforms=[]):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        traindf = pd.read_csv(csv_file, header=None)
        self.train_img_pths = traindf.iloc[:,0]
        self.transform = transform
        self.augment_transforms = augment_transforms

    def __len__(self):
        return (len(self.augment_transforms)+1)*len(self.train_img_pths)
    
    def __getitem__(self, idx):
        if idx >= len(self.train_img_pths):
            idx = idx%len(self.train_img_pths)
            transform = self.augment_transforms[idx//len(self.train_img_pths)-0]
        else:
            transform = self.transform

        img_name = self.train_img_pths[idx]
        img = Image.open(img_name)
        img = img.convert('RGB')
        img = transform(img)
        
        img_study_type = re.match(r'.*/XR_(\w*)/', img_name).group(1)
        img_label = int(
            re.match(r'.*/study\d+_(\w+)/', img_name).group(1) == 'positive'
        )
        img_study_name = re.match(r'(.*/).*png', img_name).group(1)
        
        return img, img_study_name, img_study_type, img_label
    
    def __iter__(self):
        self.idx = 0
        return self
    
    def __next__(self):
        if self.idx >= len(self):
            raise StopIteration
        
        result = self[self.idx]
        self.idx += 1
        return result

# Custom Loss Function

In [9]:
study_counts = pd.read_csv('study_counts.csv')
study_counts = study_counts.set_index('study_type')

In [10]:
def weights(study_type):
    Nt = study_counts.loc[study_type].loc['label_positive']
    At = study_counts.loc[study_type].loc['label_negative']
    return At/(At+Nt), Nt/(At+Nt), 

In [11]:
Wt = {t : weights(t) for t in study_counts.index}
Wt

{'ELBOW': (0.593185966335429, 0.4068140336645711),
 'FINGER': (0.6145710928319624, 0.3854289071680376),
 'FOREARM': (0.6378082191780822, 0.36219178082191783),
 'HAND': (0.7322749413674905, 0.26772505863250945),
 'HUMERUS': (0.5290880503144654, 0.47091194968553457),
 'SHOULDER': (0.5025659386561642, 0.4974340613438358),
 'WRIST': (0.591160787530763, 0.4088392124692371)}

In [12]:
class WeightedCrossEntropyLoss(torch.nn.modules.Module):
    def __init__(self, W, cpu=False):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.W = W
        self.T = torch.DoubleTensor if cpu else torch.cuda.DoubleTensor
    
    def forward(self, inputs, targets, keys):
        Wt1 = self.T([self.W[key][1] for key in keys])
        Wt0 = self.T([self.W[key][0] for key in keys])
        targets = targets.double()
        inputs = inputs.double()
        print(inputs)
        print(inputs.log())
#         loss = - (Wt1*targets*inputs.log() + Wt0*(1-targets)*(1-inputs).log())
        return F.binary_cross_entropy(inputs, targets, weight=Wt1)

In [13]:
del study_counts

# Model

In [35]:
batch_size = 8

# Setup the loss fxn
criterion = nn.BCELoss()


model = inception_v3(pretrained=False, num_classes=1)
# freeze everything to extract feature vector
# for param in model.parameters():
#     param.requires_grad = False

# Handle the auxilary net
num_ftrs = model.AuxLogits.fc.in_features
model.AuxLogits.fc = nn.Linear(num_ftrs, 1)

# Handle the primary net
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)

optimizer = optim.Adam(model.parameters(), lr=0.2)

In [36]:
input_size = (299, 299)
data_transforms = {
        'train': transforms.Compose([
                transforms.Resize(input_size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'valid': transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }

augment_transforms = [
    transforms.Compose([
        transforms.Resize(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
]


datasets_dict = {
    x: MuraDataset(
        csv_file='MURA-v1.1/%s_image_paths.csv' % x,
        transform=data_transforms[x],
        augment_transforms=augment_transforms if x == 'train' else [],
    ) for x in ['train', 'valid']
}

# # subset
# dataloaders_dict = {
#     x: DataLoader(
#         torch.utils.data.Subset(
#             dataset=datasets_dict[x],
#             indices=np.random.randint(
#                 len(datasets_dict[x]),
#                 size=100,
#             ),
#         ),
#         batch_size=batch_size,
#         num_workers=8,
#     ) for x in ['train', 'valid']
# }

dataloaders_dict = {
    x: DataLoader(
        datasets_dict[x],
        shuffle=True,
        batch_size=batch_size,
        num_workers=8,
    ) for x in ['train', 'valid']
}

In [37]:
def update_learning_rate(optimizer, update):
    for param_group in optimizer.param_groups:
        param_group['lr'] = update(param_group['lr'])

In [38]:
epoch_n = 8
T = batch_size*epoch_n
M = 5

def shifted_cosine_function(t, lr):
    return lr/2*(np.cos(np.pi*((t-1)%np.ceil(T/M))/np.ceil(T/M))+1)

In [39]:
def train_model(model, train_loader, criterion, optimizer, epoch_n, cycles, validation_loader):
    model = model.to(device)
    t = 0 # Total number of iterations
    models = []
    for cycle in range(cycles):
        # reset learning rate
        update_learning_rate(optimizer, lambda lr: 0.2)
        for epoch in range(epoch_n):
            model.train()
            print('=========================epoch %i=====================' % epoch)
            epoch_loss = 0.0
            running_loss = 0.0
            epoch_corrects = 0
            for i, data in enumerate(tqdm(train_loader)):
                # update learning rate
                t += 1
                update_learning_rate(optimizer, lambda lr: shifted_cosine_function(t, lr))

                inputs, _, study_type, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                Wt1 = torch.FloatTensor([Wt[key][1] for key in study_type])

                optimizer.zero_grad()

                outputs, aux_outputs = model(inputs)
                outputs = torch.mean(outputs, 1)
                aux_outputs = torch.mean(aux_outputs, 1)
                loss1 = criterion(outputs, labels.float())
                loss2 = criterion(aux_outputs, labels.float())
                loss = loss1 + 0.4*loss2
                
                loss.sum().backward()
                optimizer.step()
                preds = (outputs > 0.5).type(torch.cuda.LongTensor)
                epoch_corrects += torch.sum(preds == labels.data)

                # print statistics
                running_loss += loss.sum()
                epoch_loss += loss.sum()

            epoch_loss = epoch_loss / len(train_loader.dataset)
            epoch_accuracy = epoch_corrects.double() / len(train_loader.dataset)
            print('epoch loss: %.8f' % epoch_loss)
            print('epoch accuracy: %.8f' % epoch_accuracy)

            # else evaluate first then keep training
            _, valid_loss, valid_acc, valid_kappa = validate_model(
                [model],
                validation_loader,
                criterion,
                optimizer,
            )
            print('epoch valid_loss: %8f' % valid_loss)
            print('epoch valid_acc: %.8f' % valid_acc)
            print('epoch valid_kappa: %.8f' % valid_kappa)
        
        print('cycle %i is over.' % cycle)
        print('saving model..\n')
        models.append(model.state_dict())
    return models 

In [40]:
def validate_model(models, validation_loader, criterion, optimizer):
    """
    returns (df, acc, kappa)
        df: Pandas dataframe containing study paths and their corresponding label
        obtained by averaging the model scores on the study's images and treating any
        score higher than 0.5 as abnormal.
        
        acc: model accuracy score on studies
        kappa: model kappa score on studies
    """
    for model in models:
        model = model.to(device)
        model.eval()
    total_loss = 0.0
    
    all_study_names = []
    all_outputs = []
    
    for i, data in enumerate(tqdm(validation_loader)):
        inputs, study_names, study_types, labels = data
        all_study_names += study_names
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        
        with torch.set_grad_enabled(False):
            ensembled_outputs = []
            for model in models:
                outputs = model(inputs)
                outputs = torch.mean(outputs, 1)
                ensembled_outputs.append(outputs)
            outputs = torch.stack(ensembled_outputs)
            outputs = torch.mean(outputs, 0)
            all_outputs += outputs.tolist()
            loss = criterion(outputs, labels.float())
        
        total_loss += loss.sum()
    
    total_loss /= len(validation_loader.dataset)
    results_df = pd.DataFrame({
        'study': all_study_names,
        'prob': all_outputs,
    })
    
    predicted_labeled_studies = results_df.groupby('study').agg({'prob': 'mean'})
    predicted_labeled_studies = predicted_labeled_studies.reset_index()
    
    predicted_labeled_studies['prediction'] = (predicted_labeled_studies['prob'] > 0.5).astype('int8')
    
    valid_labeled_studies = pd.read_csv(
        'MURA-v1.1/valid_labeled_studies.csv',
        header=None,
        names= ['study', 'label']
    )
    
   
    joined_df = predicted_labeled_studies.join(valid_labeled_studies.set_index('study'), on='study')
    
    
    accuracy = accuracy_score(joined_df.label, joined_df.prediction)
    kappa = cohen_kappa_score(joined_df.label, joined_df.prediction)
    
    return joined_df, total_loss.item(), accuracy, kappa

In [None]:
best_models = train_model(model, dataloaders_dict['train'], criterion, optimizer, epoch_n, M, dataloaders_dict['valid'])

  0%|          | 0/9202 [00:00<?, ?it/s]



 10%|█         | 945/9202 [05:55<51:54,  2.65it/s] 

In [30]:
def save_models(best_models, prefix=''):
    """
    best_models: [(valid_acc, model_state_dict)]
    """
    paths = []
    for i, model in enumerate(best_models):
        path = 'models/'+prefix+'model_'+str(i)
        paths.append(path)
        # save model_state_dict
        torch.save(model, path)
    return paths

In [31]:
def load_models(paths):
    models = []
    for path in paths:
        m = inception_v3(pretrained=False, num_classes=1)
        m.load_state_dict(torch.load(path))
        models.append(m)
    
    return models

In [32]:
paths = save_models(best_models, prefix='snapshot_ensembling_')

In [33]:
models = load_models(paths)

In [34]:
_, valid_loss, valid_acc, valid_kappa = validate_model(
        models,
        dataloaders_dict['valid'],
        criterion,
        optimizer,
    )

print('epoch valid_loss: %8f' % valid_loss)
print('epoch valid_acc: %.8f' % valid_acc)
print('epoch valid_kappa: %.8f' % valid_kappa)

100%|██████████| 13/13 [00:07<00:00,  2.07it/s]

epoch valid_loss: 1.381551
epoch valid_acc: 0.61855670
epoch valid_kappa: 0.16164448



