In [None]:
! pip install timm
! pip install torchsummary
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py > /dev/null 2>&1
#!python pytorch-xla-env-setup.py --version 20210331 --apt-packages libomp5 libopenblas-dev > /dev/null 2>&1

In [None]:
import pandas as pd
import numpy as np
import torch 
import torchvision
from timm import models
from torch import nn
import timm
from torchsummary import summary
import torch.nn.functional as F
import gc
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import cv2 as cv
from sklearn.model_selection import GroupKFold
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from collections import defaultdict
from torch.nn.parallel.data_parallel import data_parallel

In [None]:
#all_models = timm.list_models(pretrained=True)
#list(x for x in all_models if 'efficientnet' in x)

In [None]:
class MultiTask(nn.Module):
    """timm EfficientNet backbone with a study classifier and an auxiliary mask head.

    The backbone is re-exposed stage by stage as ``bl_0`` (stem) ... ``bl_8``
    (conv head) so the mask head can branch off after ``bl_5`` (``blocks[4]``).
    The classifier sits on the globally average-pooled output of ``bl_8``.

    Args:
        model_name: timm model name. The model must expose ``conv_stem``/``bn1``/
            ``act1``, seven ``blocks`` stages and ``conv_head``/``bn2``/``act2``
            (the timm EfficientNet layout used below).
        num_classes: number of study-level classes produced by ``logit``.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name='efficientnet_b5', num_classes=4, pretrained=True):
        super(MultiTask, self).__init__()
        base_model = timm.create_model(model_name, pretrained=pretrained, in_chans=3)
        self.bl_0 = nn.Sequential(
            base_model.conv_stem,
            base_model.bn1,
            base_model.act1
        )
        self.bl_1 = base_model.blocks[0]
        self.bl_2 = base_model.blocks[1]
        self.bl_3 = base_model.blocks[2]
        self.bl_4 = base_model.blocks[3]
        self.bl_5 = base_model.blocks[4]
        self.bl_6 = base_model.blocks[5]
        self.bl_7 = base_model.blocks[6]
        self.bl_8 = nn.Sequential(
            base_model.conv_head,
            base_model.bn2,
            base_model.act2
        )

        # Channel widths differ between EfficientNet variants, so read them from
        # the backbone. Fixed values (2560 / 224 are efficientnet_b7's widths)
        # make the default efficientnet_b5 fail with a shape mismatch.
        head_channels = self._out_channels(base_model.conv_head)
        tap_channels = self._out_channels(base_model.blocks[4])

        self.logit = nn.Linear(head_channels, num_classes)
        self.mask = nn.Sequential(
            nn.Conv2d(tap_channels, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1, padding=0),
        )

    @staticmethod
    def _out_channels(module):
        """Return ``out_channels`` of the last ``nn.Conv2d`` registered in ``module``.

        Assumes that conv sets the module's output width, i.e. it is followed only
        by norm/activation layers (as in timm EfficientNet stages).
        TODO confirm for other backbones.

        Raises:
            ValueError: if ``module`` contains no ``nn.Conv2d``.
        """
        convs = [m for m in module.modules() if isinstance(m, nn.Conv2d)]
        if not convs:
            raise ValueError(f'{type(module).__name__} contains no nn.Conv2d')
        return convs[-1].out_channels

    def forward(self, x):
        """Return ``(logit, mask)`` for an image batch ``x``.

        ``logit`` has shape ``(B, num_classes)``. ``mask`` holds raw (pre-sigmoid)
        mask logits of shape ``(B, 1, h, w)``, taken after ``bl_5``; this is
        38x38 for 600x600 inputs.
        """
        batch_size = len(x)
        x = self.bl_0(x)
        x = self.bl_1(x)
        x = self.bl_2(x)
        x = self.bl_3(x)
        x = self.bl_4(x)
        x = self.bl_5(x)

        mask = self.mask(x)  # (B, 1, 38, 38) for 600x600 inputs

        x = self.bl_6(x)
        x = self.bl_7(x)
        x = self.bl_8(x)
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)

        logit = self.logit(x)

        return logit, mask
        

In [None]:
class Config:
    """Notebook-wide settings: the training augmentation pipeline and data paths."""

    # Random flips and rotation; then, half of the time, exactly one photometric
    # distortion picked from the OneOf group.
    augm = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(p=0.5),
        A.OneOf(
            [
                A.Blur(),
                A.GlassBlur(),
                A.GaussianBlur(),
                A.GaussNoise(),
                A.RandomGamma(),
                A.InvertImg(),
            ],
            p=0.5,
        ),
    ])

    # Directories are joined with "<image_id>.png" by plain string concatenation,
    # so each keeps its trailing slash.
    train_path = '../input/covid19-detection-890pxpng-study/train/'
    test_path = '../input/covid19-detection-890pxpng-study/test/'
    mask_path = '../input/covid19-detection-890pxpng-study/ROI Mask/'
    

In [None]:
def image_process(image_id, mask=False, train=True):
    """Load one grayscale training image (or its ROI mask) and resize it.

    Args:
        image_id: file stem; ``<directory><image_id>.png`` is read.
        mask: if True, read from ``Config.mask_path`` and resize to 38x38.
            Otherwise read from ``Config.train_path`` and resize to 600x600.
        train: if True, apply ``Config.augm``. Masks only receive its spatial
            transforms.

    Returns:
        The array read with ``cv.IMREAD_GRAYSCALE``, resized (uint8).

    Raises:
        FileNotFoundError: if the image cannot be read. For masks, this means
            both the mask and its fallback could not be read.
    """
    if mask:
        path = Config.mask_path + image_id + '.png'
        image = cv.imread(path, cv.IMREAD_GRAYSCALE)
        if image is None:
            # NOTE(review): ids without an ROI mask fall back to the X-ray itself
            # as the mask target. Confirm this is intended; an all-zero mask is
            # the usual "no ROI" target.
            path = Config.train_path + image_id + '.png'
            image = cv.imread(path, cv.IMREAD_GRAYSCALE)
        size = (38, 38)
    else:
        path = Config.train_path + image_id + '.png'
        image = cv.imread(path, cv.IMREAD_GRAYSCALE)
        size = (600, 600)

    if image is None:
        # cv.imread signals failure by returning None. Fail here with the path
        # instead of later inside cv.resize.
        raise FileNotFoundError(f'could not read image: {path}')
    image = cv.resize(image, size)

    if train:
        if mask:
            # Feed the mask through the `mask` target so only spatial transforms
            # (flips/rotation) touch it; blur/noise/gamma/invert must not corrupt
            # a target. Random parameters are still drawn independently of the
            # matching image call, so callers needing image/mask alignment should
            # augment the pair jointly.
            image = Config.augm(image=image, mask=image)['mask']
        else:
            image = Config.augm(image=image)['image']
    return image

In [None]:
class CreateDataset(Dataset):
    """Map-style dataset yielding ``{'image', 'mask', 'label'}`` dicts for one split.

    Args:
        df: split dataframe with an ``id`` column. Columns 4..7 are used as the
            label columns (NOTE(review): assumes the train.csv column layout;
            confirm). Its index may be non-contiguous, e.g. a fold subset.
        train: if True, augment each image and its mask jointly with ``Config.augm``.
    """

    def __init__(self, df, train=True):
        super(CreateDataset, self).__init__()
        self.df = df
        self.label_cols = df.columns[4:8]
        self.train = train

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, indx):
        # Samplers yield positions 0..len-1 and self.df is usually a fold subset
        # with a non-contiguous index, so index positionally into *this* frame.
        # Reading a global table here would return rows from other folds.
        row = self.df.iloc[indx]
        image_id = row['id']
        label = row[self.label_cols].to_numpy(dtype=np.float32)

        # Load both un-augmented so one shared random transform can be applied.
        image = image_process(image_id, train=False)
        mask = image_process(image_id, mask=True, train=False)
        if self.train:
            image, mask = self._augment_pair(image, mask)

        return {'image': image / 255.,
                'mask': mask / 255.,
                'label': label}

    @staticmethod
    def _augment_pair(image, mask):
        """Apply ``Config.augm`` to image and mask with the same random parameters.

        Pixel-level transforms hit only the image; spatial ones hit both.
        """
        mask_h, mask_w = mask.shape[:2]
        # Recent albumentations versions require image and mask to share a size:
        # lift the low-res mask to image size for the joint call, then shrink it back.
        mask_full = cv.resize(mask, (image.shape[1], image.shape[0]),
                              interpolation=cv.INTER_LINEAR)
        out = Config.augm(image=image, mask=mask_full)
        mask_small = cv.resize(out['mask'], (mask_w, mask_h),
                               interpolation=cv.INTER_AREA)
        return out['image'], mask_small

In [None]:
def collate(batch):
    """Collate dataset items into a dict of batched float32 tensors.

    Args:
        batch: list of dicts with 2-D ``'image'`` and ``'mask'`` arrays (each
            the same size within the batch) and a 1-D ``'label'`` array.

    Returns:
        defaultdict with ``'image'`` of shape (B, 3, H, W), ``'mask'`` of shape
        (B, 1, h, w) and ``'label'`` of shape (B, C), all ``torch.float32``.
        The grayscale image is repeated over 3 channels for the 3-channel model
        input.
    """
    columns = defaultdict(list)
    for item in batch:
        for key, value in item.items():
            columns[key].append(value)

    label = np.ascontiguousarray(np.stack(columns['label']), dtype=np.float32)
    columns['label'] = torch.from_numpy(label)

    # Add a channel axis, replicate it to 3 channels and cast to float32.
    # Upstream /255. scaling yields float64, which would clash with the float32
    # model weights and logits.
    image = np.repeat(np.stack(columns['image'])[:, None], 3, axis=1)
    columns['image'] = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))

    mask = np.stack(columns['mask'])[:, None]
    columns['mask'] = torch.from_numpy(np.ascontiguousarray(mask, dtype=np.float32))

    return columns

In [None]:
# Training table; `fold` starts at -1 (unassigned) until the fold split fills it in.
df = pd.read_csv('../input/covid19-detection-890pxpng-study/train.csv').assign(fold=-1)

In [None]:
# Give every row a validation fold in 0..4; rows sharing an `id` always land in
# the same fold. df has the default RangeIndex from read_csv, so the positional
# split indices double as .loc labels.
groupkfold = GroupKFold(n_splits=5)
fold_splits = groupkfold.split(df, groups=df['id'].tolist())
for fold, (train_indx, valid_indx) in enumerate(fold_splits):
    df.loc[valid_indx, 'fold'] = fold

In [None]:
def do_valid(net, valid_loader):
    """Evaluate ``net`` on ``valid_loader`` and return study-level metrics.

    Args:
        net: module whose forward returns ``(logit, mask)``; only ``logit`` is used.
        valid_loader: iterable of batches shaped like ``collate`` output: dicts
            with an ``'image'`` tensor and a one-hot ``'label'`` tensor.

    Returns:
        ``[loss, top1, top2]``: the mean cross-entropy of the softmax
        probabilities against the argmax of the one-hot labels, and the top-1 /
        top-2 accuracy.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    device = next(net.parameters()).device
    valid_probability = []
    valid_truth = []

    net.eval()
    with torch.no_grad():
        for batch in valid_loader:
            image = batch['image'].to(device)
            label = batch['label'].argmax(-1)  # one-hot -> class index
            # data_parallel needs CUDA devices; use a plain forward elsewhere.
            if device.type == 'cuda':
                logit, _ = data_parallel(net, image)
            else:
                logit, _ = net(image)
            valid_probability.append(F.softmax(logit, -1).cpu().numpy())
            valid_truth.append(label.cpu().numpy())

    if not valid_truth:
        raise ValueError('valid_loader yielded no batches')

    truth = np.concatenate(valid_truth)
    probability = np.concatenate(valid_probability)
    # Rank classes (axis 1) by descending probability for each sample.
    predict = probability.argsort(-1)[:, ::-1]

    picked = probability[np.arange(len(truth)), truth]
    loss = float(-np.log(np.clip(picked, 1e-8, 1.0)).mean())

    hits = predict == truth.reshape(-1, 1)
    topk = hits.mean(0).cumsum()
    return [loss, topk[0], topk[1]]


In [None]:
# Cross-validated training: for each fold, a fresh MultiTask model is trained on
# the remaining folds for a fixed number of optimisation steps.
NUM_FOLDS = 5
MAX_ITERATIONS = 20  # per-fold step budget (smoke-run value); raise for real training
LOG_EVERY = 100      # iterations between updates of the running-mean train_loss

for i in range(NUM_FOLDS):
    valid_df = df[df.fold == i]
    train_df = df[df.fold != i]

    train_dataset = CreateDataset(train_df, train=True)
    valid_dataset = CreateDataset(valid_df, train=False)

    train_sampler = RandomSampler(train_dataset)
    valid_sampler = SequentialSampler(valid_dataset)

    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=8,
        drop_last=True,
        num_workers=4,
        pin_memory=True,
        # Distinct numpy seed per worker, derived from torch's per-worker seed.
        worker_init_fn=lambda id_: np.random.seed(torch.initial_seed() // 2 ** 32 + id_),
        collate_fn=collate,
    )
    # Validation loader (for do_valid); not evaluated inside this loop yet.
    valid_loader = DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        batch_size=16,
        drop_last=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate,
    )
    if len(train_loader) == 0:
        # Without a single batch, the while-loop below would never terminate.
        raise ValueError(f'fold {i}: training split is smaller than one batch')

    net = MultiTask().cuda()
    # Create the optimizer once per fold. Creating Adam inside the step loop
    # would discard its moment estimates on every iteration.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()))
    rate = optimizer.param_groups[0]['lr']

    train_loss = np.zeros(2, np.float32)  # running mean of [cls, mask] losses
    sum_train_loss = np.zeros_like(train_loss)
    sum_train = 0

    iteration = 0
    epoch = 0.0
    while iteration < MAX_ITERATIONS:
        for batch in train_loader:
            if iteration >= MAX_ITERATIONS:
                break  # stop mid-epoch once the step budget is spent

            # .float() guards against float64 batches (numpy /255. yields float64).
            image = batch['image'].cuda().float()
            truth_mask = batch['mask'].cuda().float()
            label = batch['label'].cuda().argmax(-1)  # one-hot -> class index

            net.train()
            optimizer.zero_grad()

            logit, mask = data_parallel(net, image)
            loss0 = F.cross_entropy(logit, label)                          # study class
            loss1 = F.binary_cross_entropy_with_logits(mask, truth_mask)  # ROI mask

            (loss0 + loss1).backward()
            optimizer.step()

            epoch += 1 / len(train_loader)
            iteration += 1

            sum_train_loss += np.array([loss0.item(), loss1.item()], np.float32)
            sum_train += 1
            if iteration % LOG_EVERY == 0:
                train_loss = sum_train_loss / (sum_train + 1e-12)
                sum_train_loss[...] = 0
                sum_train = 0

            print(f'\rfold {i}  iter {iteration:5d}  epoch {epoch:6.3f}  lr {rate:.1e}  '
                  f'cls {loss0.item():.4f}  mask {loss1.item():.4f}',
                  end='', flush=True)
    print()

    
   