# Train model

In [23]:
# necessary libraries for training the model 
import torch
import torch.nn as nn
import torchvision 
from torchvision import transforms
import torchmetrics
import pytorch_lightning as pl # handle the complete traning phrase
from pytorch_lightning.callbacks import ModelCheckpoint # frequently stop the current weight
from pytorch_lightning.loggers import TensorBoardLogger # enable to login the tensor board
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

## Image Data Preparation

In [24]:
def load_file(path):
    """Read the ``.npy`` array stored at *path* and return it as float32.

    Serves as the ``loader`` callable for the ``DatasetFolder`` datasets below.
    """
    array = np.load(path)
    return array.astype(np.float32)

In [25]:
# Preprocessing pipelines. Both convert the loaded array to a tensor and
# normalize it with the same constants (presumably the dataset mean/std --
# TODO confirm). Only the training pipeline adds a small random affine
# augmentation (rotation, translation, scaling).
# NOTE(review): RandomAffine runs after Normalize, so any fill value it uses
# for uncovered pixels is in normalized units.
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.0021, 0.0010),
        transforms.RandomAffine(
            degrees=(-5, 5),
            translate=(0, 0.05),
            scale=(0.9, 1.1),
        ),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.0021, 0.0010),
    ]
)

In [26]:
## Build the train / validation datasets. Each root holds one sub-folder per
## class; every ``.npy`` file inside is read with ``load_file`` and then
## passed through the matching transform pipeline.
train_set = torchvision.datasets.DatasetFolder(
    root="Outliers/train",
    loader=load_file,
    extensions="npy",
    transform=train_transform,
)
val_set = torchvision.datasets.DatasetFolder(
    root="Outliers/val",
    loader=load_file,
    extensions="npy",
    transform=val_transform,
)

In [29]:
# Sanity check: report every validation sample whose transformed tensor is not
# 1 x 224 x 224. Index each sample only once: every ``val_set[i]`` reads the
# .npy file from disk and re-applies the transform.
for i in range(len(val_set)):
    sample = val_set[i][0]
    if sample.shape != torch.Size([1, 224, 224]):
        print(f"index: {i}  shape:{sample.shape}")
print("Done!")

Done!


In [65]:
# Show the class-name -> integer-label mapping that DatasetFolder derived from
# the class sub-folder names (the output below lists 14 classes, including
# "No Finding" at index 10).
train_set.class_to_idx

{'Atelectasis': 0,
 'Cardiomegaly': 1,
 'Consolidation': 2,
 'Edema': 3,
 'Effusion': 4,
 'Emphysema': 5,
 'Fibrosis': 6,
 'Hernia': 7,
 'Infiltration': 8,
 'Mass': 9,
 'No Finding': 10,
 'Nodule': 11,
 'Pleural_Thickening': 12,
 'Pneumothorax': 13}

In [80]:
# Wrap the datasets in DataLoaders. Only the training loader shuffles, so the
# validation batches arrive in the same order every epoch.
batch_size = 64
nums_worker = 4

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=nums_worker,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=nums_worker,
)

## Utility Functions

In [None]:
### deep generalized max-pooling
class GMP(nn.Module):
    """Global generalized max-pooling over the spatial positions of a feature map.

    Adjusted version of V. Christlein, L. Spranger, M. Seuret, A. Nicolaou,
    P. Král, A. Maier, "Deep Generalized Max Pooling",
    arXiv preprint arXiv:1908.05040 (2019).

    For each sample, the D-dimensional features at the N = H * W positions
    (rows of X, shape (N, D)) are pooled into ``xi = alpha^T X``, where
    ``alpha`` solves ``(X X^T + lamb * I) alpha = 1``. The result is then
    L2-normalized.

    Args:
        lamb: initial value of the learnable regularizer ``lambda``.
    """

    def __init__(self, lamb=10**3):
        super(GMP, self).__init__()
        # Created on the CPU. ``module.to(device)`` moves it, and ``forward``
        # also moves it to the input's device, so a freshly constructed GMP
        # works on CPU and GPU alike (no hard-coded ``.cuda()``).
        self.lamb = nn.Parameter(lamb * torch.ones(1))

    def forward(self, x):
        """Pool ``x`` of shape (B, D, H, W) into an L2-normalized (B, D) tensor."""
        B, D, H, W = x.shape
        N = H * W
        # reshape x, s.t. we can use the gmp formulation as a global pooling
        # operation: (B, D, H, W) -> (B, N, D)
        x = x.reshape(B, D, N).permute(0, 2, 1)
        # linear kernel between all pairs of spatial positions: (B, N, N)
        K = torch.bmm(x, x.transpose(1, 2))
        # solve the linear system (K + lambda * I) * alpha = ones; all helper
        # tensors follow the input's device and dtype
        identity = torch.eye(N, device=x.device, dtype=x.dtype)
        lamb = self.lamb.to(device=x.device, dtype=x.dtype)
        A = K + lamb * identity
        ones = torch.ones(B, N, 1, device=x.device, dtype=x.dtype)
        alphas = torch.linalg.solve(A, ones).view(B, 1, N)
        # weighted combination of the position features: (B, 1, D) -> (B, D)
        xi = torch.bmm(alphas, x).view(B, D)
        # L2 normalization
        return nn.functional.normalize(xi)

In [None]:
import torch.nn.functional as F
## Focal Loss function
class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss with per-class weights.

    For each sample, with ``p_t`` the softmax probability of its target class::

        loss_i = weight[target_i] * (1 - p_t) ** gamma * (-log p_t)

    With ``gamma=0`` this equals
    ``F.cross_entropy(input, target, weight=weight, reduction=reduction)``.

    Args:
        device: optional device to move ``weight`` to at construction time.
            ``None`` leaves it where it is; ``weight`` is stored as a module
            buffer, so it also follows the module on ``.to(...)``.
        weight: per-class weights, shape (C,). ``None`` selects the default
            14-class vector (1.0 everywhere except 0.5 at index 10).
        gamma: focusing parameter; larger values down-weight easy samples.
        reduction: ``'mean'`` (weighted mean, normalized by the sum of the
            target weights like ``F.cross_entropy``), ``'sum'`` or ``'none'``.

    Raises:
        ValueError: from ``forward`` for an unknown ``reduction``.
    """

    def __init__(self, device=None, weight=None, gamma=2, reduction='mean'):
        if weight is None:
            # Build a fresh tensor per instance; a tensor default argument
            # would be one object shared by every FocalLoss.
            weight = torch.tensor([1.0] * 10 + [0.5] + [1.0] * 3)
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.device = device
        if device is not None:
            self.weight = self.weight.to(device)

    def forward(self, input, target):
        # Per-sample *unweighted* cross entropy, i.e. -log p_t, so p_t is the
        # true target probability (class weights are applied separately).
        ce = F.cross_entropy(input, target, reduction='none')
        pt = torch.exp(-ce)
        sample_weight = self.weight.to(device=input.device, dtype=ce.dtype)[target]
        loss = sample_weight * (1 - pt) ** self.gamma * ce
        if self.reduction == 'none':
            return loss
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.sum() / sample_weight.sum()
        raise ValueError(f"unsupported reduction: {self.reduction!r}")

In [None]:
### Channel attention (squeeze-and-excitation, SENet)
# Adapted from Moskomule, 2019, https://github.com/moskomule/senet.pytorch

class SELayer(nn.Module):
    """Squeeze-and-excitation block: rescale each channel by a learned gate.

    Args:
        channel: number of input channels C.
        reduction: bottleneck factor; the hidden layer has ``channel // reduction`` units.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` (B, C, H, W) scaled channel-wise by sigmoid gates."""
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)      # (B, C)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)  # (B, C, 1, 1)
        return x * gates.expand_as(x)

In [None]:
### spatial attention

class Spatial_Attention_Module(nn.Module):
    """Gate every spatial position by a mask computed from channel statistics.

    The channel-wise max and mean of the input are stacked into a 2-channel
    map, convolved with a k x k kernel and passed through a sigmoid. The
    resulting (B, 1, H, W) mask multiplies every channel of the input.

    Args:
        k: kernel size, one of 3, 5 or 7 (7 is noted as best in the paper).
    """

    def __init__(self, k: int):
        super().__init__()
        self.avg_pooling = torch.mean
        self.max_pooling = torch.max
        # "same" padding keeps the spatial size unchanged
        assert k in [3, 5, 7], "kernel size = 1 + 2 * padding, so kernel size must be 3, 5, 7"
        pad = (k - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=(k, k), stride=(1, 1), padding=(pad, pad), bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # collapse the channel axis to one map each, keeping dims
        mean_map = self.avg_pooling(x, dim=1, keepdim=True)
        max_map = self.max_pooling(x, dim=1, keepdim=True).values
        stacked = torch.cat((max_map, mean_map), dim=1)  # max first, then mean
        mask = self.sigmoid(self.conv(stacked))
        return x * mask

In [None]:
### coordinate attention
# The following code is from the paper "Coordinate Attention for Efficient
# Mobile Network Design", Hou et al., 2021, https://github.com/houqb/CoordAttention

class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``, a piecewise-linear sigmoid."""

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3  # new tensor, so an in-place ReLU6 never touches x
        return self.relu(shifted) / 6

class h_swish(nn.Module):
    """Hard swish: ``x * h_sigmoid(x)``."""

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate

class CoordAtt(nn.Module):
    """Coordinate attention: gate ``x`` with separate height- and width-wise maps.

    ``x`` is average-pooled along each spatial axis. The two pooled strips are
    encoded together by a shared 1x1 conv + BatchNorm + hard-swish, split
    again, and each half becomes a sigmoid attention map via its own 1x1 conv.

    Args:
        inp: channels of the input.
        oup: channels of each attention map. The maps multiply ``x``, so this
            has to broadcast against ``inp`` (equal to it, or 1).
        reduction: bottleneck factor; the shared conv has
            ``max(8, inp // reduction)`` channels.
    """

    def __init__(self, inp, oup, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # -> (n, c, h, 1)
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # -> (n, c, 1, w)

        hidden = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, hidden, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(hidden, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(hidden, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        _, _, height, width = x.size()
        strip_h = self.pool_h(x)                      # (n, c, h, 1)
        strip_w = self.pool_w(x).permute(0, 1, 3, 2)  # (n, c, w, 1)

        # encode both strips jointly along the stacked spatial axis
        joint = torch.cat([strip_h, strip_w], dim=2)
        encoded = self.act(self.bn1(self.conv1(joint)))
        enc_h, enc_w = torch.split(encoded, [height, width], dim=2)

        attn_h = self.conv_h(enc_h).sigmoid()                      # (n, oup, h, 1)
        attn_w = self.conv_w(enc_w.permute(0, 1, 3, 2)).sigmoid()  # (n, oup, 1, w)
        return x * attn_w * attn_h

In [None]:
## Multi-head attention map: channel (SE), spatial and coordinate attention
## computed on a 1-channel projection of the input and fused by a 1x1 conv.
class multi_att(nn.Module):
    """Fuse three attention branches into a single-channel map.

    ``forward`` takes (B, inchannels, H, W) and returns (B, 1, H, W):

    1. ``conv1`` projects the input to 1 channel.
    2. The SE and spatial outputs are concatenated and passed through a sigmoid.
    3. The coordinate-attention output is appended.
    4. ``conv2`` fuses the 3 channels back to 1.

    Args:
        inchannels: channels of the input.
        se_input: channels given to ``SELayer``. It must match the 1 channel
            produced by ``conv1``.
        spatial_input: kernel size of the spatial-attention conv.
        coord_input, coord_output: ``CoordAtt`` input / output channels.
        extract, device, outchannels: accepted but not used.

    NOTE(review): with ``se_input=1``, ``SELayer``'s bottleneck is
    ``1 // 16 == 0`` units wide. Its output is then always
    ``sigmoid(0) * x == 0.5 * x``, so the SE branch carries no learned
    signal. A smaller SE reduction would fix it, but changes parameter
    shapes, so existing checkpoints would no longer load.
    """

    def __init__(self, inchannels=1, se_input=1, spatial_input=7,
                 coord_input=1, coord_output=1, extract=False,
                 device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                 outchannels=1):
        super().__init__()
        # 1x1 projection of the input down to a single channel
        self.conv1 = torch.nn.Conv2d(in_channels=inchannels, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.seatt = SELayer(se_input)
        self.spatialatt = Spatial_Attention_Module(spatial_input)
        self.sigmoid = torch.nn.Sigmoid()
        self.coordatt = CoordAtt(coord_input, coord_output)
        # fuses the SE, spatial and coordinate branch outputs (1 channel each)
        self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        projected = self.conv1(x)
        se_out = self.seatt(projected)
        spatial_out = self.spatialatt(projected)
        gated = self.sigmoid(torch.concat((se_out, spatial_out), dim=1))
        coord_out = self.coordatt(projected)
        return self.conv2(torch.concat((gated, coord_out), dim=1))

In [None]:
def split_kfold(dataset, k_fold):
    '''Split ``dataset`` into ``k_fold`` contiguous (unshuffled) train/validation folds.

    Fold ``i`` validates on one contiguous slice of indices and trains on all
    other indices. When ``len(dataset)`` is not divisible by ``k_fold``, the
    first ``len(dataset) % k_fold`` folds get one extra validation sample.
    This guarantees every index is validated exactly once and no fold refers
    past the end of the dataset.

    Arguments:
    dataset: the full dataset (anything with ``__len__`` and ``__getitem__``)
    k_fold: the number of cross-validation folds (must be >= 1)

    return:
    train_list: the ``k_fold`` training ``Subset``s
    vali_list: the ``k_fold`` validation ``Subset``s

    raises:
    ValueError: if ``k_fold < 1``
    '''
    if k_fold < 1:
        raise ValueError(f"k_fold must be >= 1, got {k_fold}")

    total_size = len(dataset)
    # integer fold sizes; the remainder is spread over the first folds
    base_size, remainder = divmod(total_size, k_fold)

    train_list = []
    vali_list = []
    start = 0
    for i in range(k_fold):
        stop = start + base_size + (1 if i < remainder else 0)
        ## indices: validation = [start, stop), training = everything else
        train_indices = list(range(0, start)) + list(range(stop, total_size))
        vali_indices = list(range(start, stop))
        fold_train = torch.utils.data.dataset.Subset(dataset, train_indices)
        fold_vali = torch.utils.data.dataset.Subset(dataset, vali_indices)
        print("The length of the training set is {}".format(len(fold_train)))
        print("The length of the validation set is {}".format(len(fold_vali)))
        train_list.append(fold_train)
        vali_list.append(fold_vali)
        start = stop

    return train_list, vali_list

In [None]:
## Model Training

In [None]:
## create the training model (deep generalized max-pooling version)
class GMP_Model(pl.LightningModule):
    """ResNet-101 14-class classifier that pools its final feature map with GMP.

    ``conv1`` is replaced by a freshly initialised 1-channel conv (inputs are
    single-channel). The ResNet trunk's last feature map is pooled by ``GMP``
    instead of ``avgpool``, and ``fc`` maps the 2048-d result to 14 logits.

    Training uses ``FocalLoss`` and Adam (lr=1e-4). Accuracy, specificity and
    F1 are logged per step and aggregated per epoch.
    """

    def __init__(self):
        super(GMP_Model, self).__init__()

        ## initialize the model
        self.model = torchvision.models.resnet101(pretrained=True)
        # convert to 1 channel (this layer is re-initialised, not pretrained)
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.model.fc = torch.nn.Linear(2048, 14)

        ## ResNet trunk without avgpool and fc: the feature map that GMP pools
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

        ## optimizer, loss and metrics
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        # FocalLoss's first parameter is the device for its class weights;
        # pass it explicitly (a bare FocalLoss() call needs a default for it).
        self.loss_fn = FocalLoss(self.device)

        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.train_sep = torchmetrics.Specificity()
        self.val_sep = torchmetrics.Specificity()
        self.train_f1 = torchmetrics.F1Score()
        self.val_f1 = torchmetrics.F1Score()

    def forward(self, data):
        """Return (B, 14) logits for a batch of 1-channel images."""
        feature_map = self.feature_map(data)
        # NOTE(review): a new GMP is built on every call, so its learnable
        # ``lamb`` is not a parameter of this model. The optimizer never
        # updates it and it is not saved in checkpoints; it always starts
        # from its initial value.
        gmp = GMP()
        max_outs = gmp(feature_map)  # generalized max pooling replaces avgpool
        return self.model.fc(max_outs)

    # training process
    def training_step(self, batch, batch_idx):
        x_ray, label = batch
        label = label.long()
        pred = self(x_ray)
        loss = self.loss_fn(pred, label)  # compute loss

        # log the loss and the batch metrics (calling a metric also adds the
        # batch to its accumulated epoch state)
        self.log("Train loss", loss)
        self.log("Step Train Acc", self.train_acc(pred, label.int()))
        self.log("Step Train Sep", self.train_sep(pred, label.int()))
        self.log("Step Train F1", self.train_f1(pred, label.int()))

        return loss

    # compute the whole training set's metrics
    def training_epoch_end(self, outs):
        self._log_epoch_metrics([
            ("Train ACC", self.train_acc),
            ("Train Sep", self.train_sep),
            ("Train F1", self.train_f1),
        ])

    # validation process
    def validation_step(self, batch, batch_idx):
        x_ray, label = batch
        label = label.long()
        pred = self(x_ray)
        loss = self.loss_fn(pred, label)

        self.log("Val loss", loss)
        self.log("Step Val Acc", self.val_acc(pred, label.int()))
        self.log("Step Val Sep", self.val_sep(pred, label.int()))
        self.log("Step Val F1", self.val_f1(pred, label.int()))

    # compute the whole validation's metrics
    def validation_epoch_end(self, outs):
        self._log_epoch_metrics([
            ("Val ACC", self.val_acc),
            ("Val Sep", self.val_sep),
            ("Val F1", self.val_f1),
        ])

    def _log_epoch_metrics(self, named_metrics):
        """Log ``metric.compute()`` under each name, then reset the metric.

        ``compute()`` does not clear a torchmetrics metric's accumulated
        state. Without ``reset()``, each "epoch" value would aggregate every
        batch seen so far, including earlier epochs (and, for validation,
        any sanity-check batches).
        """
        for name, metric in named_metrics:
            self.log(name, metric.compute())
            metric.reset()

    # return the list of all optimizers
    def configure_optimizers(self):
        return [self.optimizer]

In [None]:
# Train the GMP baseline for 30 epochs on one GPU, logging every step to
# TensorBoard. The checkpoint callback keeps the 10 checkpoints with the
# highest validation F1 ("Val F1").
gmp = GMP_Model()
checkpoint_callback = ModelCheckpoint(
    monitor="Val F1",
    save_top_k=10,
    mode="max",
)
trainer = pl.Trainer(
    logger=TensorBoardLogger(save_dir="./logs_gmpbase"),
    log_every_n_steps=1,
    callbacks=checkpoint_callback,
    max_epochs=30,
    gpus=1,
)
trainer.fit(gmp, train_loader, val_loader)

In [None]:
## create the training model: multi-head attention applied to the input image
class Self_multiatt_Model(pl.LightningModule):
    """ResNet-101 14-class classifier fed with a ``multi_att`` attention map of the input.

    ``multi_att`` turns the (B, inchannels, H, W) input into a 1-channel map.
    That map goes through a ResNet-101 trunk with a 1-channel ``conv1``. The
    2048-channel feature map is pooled by ``GMP`` and classified by
    ``Linear(2048, 14)``.

    Training uses ``FocalLoss`` and Adam (lr=1e-4) over all parameters,
    including the attention module.
    """

    def __init__(self, inchannels, extract=False):
        super(Self_multiatt_Model, self).__init__()

        ## initialize the model
        self.model = torchvision.models.resnet101(pretrained=True)
        self.multiatt = multi_att(inchannels=inchannels, extract=extract)
        # multi_att emits a 1-channel map, so the ResNet stem must accept
        # 1 channel (the stock conv1 expects 3)
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        ## change fully connected layer: GMP pools the 2048-channel ResNet
        ## feature map into a 2048-d vector (a Linear(3, 14) cannot consume it)
        self.model.fc = torch.nn.Linear(2048, 14)

        ## ResNet trunk without avgpool and fc: the feature map that GMP pools
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

        ## optimizer, loss and metrics
        # self.parameters() also covers self.multiatt; with only
        # self.model.parameters() the attention module would never be trained
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        # FocalLoss's first parameter is the device for its class weights
        self.loss_fn = FocalLoss(self.device)

        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.train_sep = torchmetrics.Specificity()
        self.val_sep = torchmetrics.Specificity()
        self.train_f1 = torchmetrics.F1Score()
        self.val_f1 = torchmetrics.F1Score()

    def forward(self, data):
        """Return (B, 14) logits for a batch of inputs with ``inchannels`` channels."""
        attention_map = self.multiatt(data)  # (B, 1, H, W)
        feature_map = self.feature_map(attention_map)
        # NOTE(review): a new GMP per call means its ``lamb`` is never trained
        # or checkpointed.
        gmp = GMP()
        max_outs = gmp(feature_map)  # generalized max pooling replaces avgpool
        return self.model.fc(max_outs)

    # training process
    def training_step(self, batch, batch_idx):
        x_ray, label = batch
        label = label.long()
        pred = self(x_ray)
        loss = self.loss_fn(pred, label)  # compute loss

        # log the loss and the batch metrics
        self.log("Train loss", loss)
        self.log("Step Train Acc", self.train_acc(pred, label.int()))
        self.log("Step Train Sep", self.train_sep(pred, label.int()))
        self.log("Step Train F1", self.train_f1(pred, label.int()))

        return loss

    # compute the whole training set's metrics
    def training_epoch_end(self, outs):
        self._log_epoch_metrics([
            ("Train ACC", self.train_acc),
            ("Train Sep", self.train_sep),
            ("Train F1", self.train_f1),
        ])

    # validation process
    def validation_step(self, batch, batch_idx):
        x_ray, label = batch
        label = label.long()
        pred = self(x_ray)
        loss = self.loss_fn(pred, label)

        self.log("Val loss", loss)
        self.log("Step Val Acc", self.val_acc(pred, label.int()))
        self.log("Step Val Sep", self.val_sep(pred, label.int()))
        self.log("Step Val F1", self.val_f1(pred, label.int()))

    # compute the whole validation's metrics
    def validation_epoch_end(self, outs):
        self._log_epoch_metrics([
            ("Val ACC", self.val_acc),
            ("Val Sep", self.val_sep),
            ("Val F1", self.val_f1),
        ])

    def _log_epoch_metrics(self, named_metrics):
        """Log ``metric.compute()`` under each name, then reset the metric.

        ``compute()`` alone does not clear the accumulated state, so without
        ``reset()`` the epoch values would aggregate all previous epochs.
        """
        for name, metric in named_metrics:
            self.log(name, metric.compute())
            metric.reset()

    # return the list of all optimizers
    def configure_optimizers(self):
        return [self.optimizer]

In [None]:
# Train the input-attention model for 30 epochs; the checkpoint callback keeps
# the 10 checkpoints with the highest validation F1 ("Val F1").
Self_multiatt = Self_multiatt_Model(1)
checkpoint_callback = ModelCheckpoint(monitor="Val F1", save_top_k=10, mode="max")

# one GPU; log_every_n_steps=1 logs every training batch
gpus = 1
trainer = pl.Trainer(
    logger=TensorBoardLogger(save_dir="./logs_self_multiatt"),
    log_every_n_steps=1,
    callbacks=checkpoint_callback,
    max_epochs=30,
    gpus=gpus,
)
trainer.fit(Self_multiatt, train_loader, val_loader)

In [None]:
## create the training model for our proposed: add multi-head attention in the end
class Multiatt_proposed(pl.LightningModule):
    """ResNet-101 trunk followed by ``multi_att``, GMP and a 14-class ``fc``.

    The 1-channel input goes through a ResNet-101 trunk with a 1-channel
    ``conv1``. Its (B, 2048, h, w) feature map is reduced by ``multi_att`` to
    a (B, 1, h, w) attention map, which is pooled by ``GMP`` and classified by
    ``Linear(1, 14)``.

    Training uses ``FocalLoss`` and Adam (lr=1e-4) over all parameters,
    including the attention module.

    NOTE(review): because ``multi_att`` outputs a single channel, GMP pools
    to a (B, 1) vector, and its L2 normalisation maps every sample to
    essentially +1 or -1. ``fc`` can therefore emit only about two distinct
    logit vectors. Consider gating the 2048-channel feature map with the
    attention map before GMP and using ``Linear(2048, 14)`` (this changes
    parameter shapes).
    """

    def __init__(self, inchannels=2048, extract=False):
        super(Multiatt_proposed, self).__init__()

        ## initialize the model
        self.model = torchvision.models.resnet101(pretrained=True)
        # convert to 1 channel (this layer is re-initialised, not pretrained)
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.multiatt = multi_att(inchannels=inchannels, extract=extract)

        self.model.fc = torch.nn.Linear(1, 14)

        ## ResNet trunk without avgpool and fc
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

        ## optimizer, loss and metrics
        # self.parameters() also covers self.multiatt; with only
        # self.model.parameters() the attention module would never be trained
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        # FocalLoss's first parameter is the device for its class weights
        self.loss_fn = FocalLoss(self.device)

        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.train_sep = torchmetrics.Specificity()
        self.val_sep = torchmetrics.Specificity()
        self.train_f1 = torchmetrics.F1Score()
        self.val_f1 = torchmetrics.F1Score()

    def forward(self, data):
        """Return (B, 14) logits for a batch of 1-channel images."""
        feature_map = self.feature_map(data)  # output feature
        multi_attention = self.multiatt(feature_map)  # (B, 1, h, w)
        # NOTE(review): a new GMP per call means its ``lamb`` is never trained
        # or checkpointed.
        gmp = GMP()
        max_outs = gmp(multi_attention)  # generalized max pooling replaces avgpool
        return self.model.fc(max_outs)

    # training process
    def training_step(self, batch, batch_idx):
        x_ray, label = batch
        label = label.long()
        pred = self(x_ray)
        loss = self.loss_fn(pred, label)  # compute loss

        # log the loss and the batch metrics
        self.log("Train loss", loss)
        self.log("Step Train Acc", self.train_acc(pred, label.int()))
        self.log("Step Train Sep", self.train_sep(pred, label.int()))
        self.log("Step Train F1", self.train_f1(pred, label.int()))

        return loss

    # compute the whole training set's metrics
    def training_epoch_end(self, outs):
        self._log_epoch_metrics([
            ("Train ACC", self.train_acc),
            ("Train Sep", self.train_sep),
            ("Train F1", self.train_f1),
        ])

    # validation process
    def validation_step(self, batch, batch_idx):
        x_ray, label = batch
        label = label.long()
        pred = self(x_ray)
        loss = self.loss_fn(pred, label)

        self.log("Val loss", loss)
        self.log("Step Val Acc", self.val_acc(pred, label.int()))
        self.log("Step Val Sep", self.val_sep(pred, label.int()))
        self.log("Step Val F1", self.val_f1(pred, label.int()))

    # compute the whole validation's metrics
    def validation_epoch_end(self, outs):
        self._log_epoch_metrics([
            ("Val ACC", self.val_acc),
            ("Val Sep", self.val_sep),
            ("Val F1", self.val_f1),
        ])

    def _log_epoch_metrics(self, named_metrics):
        """Log ``metric.compute()`` under each name, then reset the metric.

        ``compute()`` alone does not clear the accumulated state, so without
        ``reset()`` the epoch values would aggregate all previous epochs.
        """
        for name, metric in named_metrics:
            self.log(name, metric.compute())
            metric.reset()

    # return the list of all optimizers
    def configure_optimizers(self):
        return [self.optimizer]

In [None]:
# Train the proposed model (multi-head attention after the ResNet features).
multiatt_prop = Multiatt_proposed(2048)

# Create a fresh checkpoint callback (keep the 10 checkpoints with the highest
# "Val F1"). Reusing the previous cell's instance would carry over its state:
# its best-k bookkeeping, and in PyTorch Lightning 1.x its already-resolved
# dirpath. This model's checkpoints would then be ranked against, and written
# next to, the previous model's.
checkpoint_callback = ModelCheckpoint(monitor="Val F1", save_top_k=10, mode="max")

# set the trainer
gpus = 1

trainer = pl.Trainer(logger=TensorBoardLogger(save_dir="./logs_proposed_multiatt"), log_every_n_steps=1,
                     callbacks=checkpoint_callback,
                     max_epochs=6, gpus=gpus)  # log every batch
trainer.fit(multiatt_prop, train_loader, val_loader)