In [None]:
import pandas as pd
from glob import glob
from tqdm import tqdm

import timm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from model import MyModel
from data import get_valid_transforms, MyDataset

# Prefer the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print("device:", device)

device: cuda


In [None]:
class MyModel_drop(nn.Module):  # model from team member
    """timm backbone whose `fc` head is replaced by a small two-stage classifier.

    The backbone's `fc` layer is swapped for a linear projection down to a
    quarter of its input width. Dropout (p=0.5) and a final linear layer then
    turn that projection into `n_classes` logits.

    Args:
        n_classes: Number of output logits.
        model_name: Architecture name handed to `timm.create_model`. The
            created model must expose an `fc` layer with `in_features`.
    """

    def __init__(self, n_classes, model_name):
        super().__init__()
        # Built without pretrained weights (pretrained=False).
        backbone = timm.create_model(model_name, pretrained=False)

        width = backbone.fc.in_features
        bottleneck = width // 4
        # Replace the stock head before registering the backbone on this module.
        backbone.fc = nn.Linear(width, bottleneck, bias=True)

        self.feature = backbone
        self.out_features = width
        self.out = nn.Linear(bottleneck, n_classes, bias=True)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return logits computed as backbone -> dropout -> output layer."""
        return self.out(self.drop(self.feature(x)))

In [None]:
# Saved checkpoints, one list per backbone directory (glob already returns a list).
densenet_paths = glob('./saved/densenetblur121d/*')
resnet50_paths = glob('./saved/gluon_seresnext50_32x4d/*')
inception_paths = glob('./saved/inception_v3/*')
resnet26_paths = glob('./saved/seresnext26d_32x4d/*')

# Checkpoints that take part in the ensemble; inception_paths is collected but not included.
ensemble_paths = [*resnet50_paths, *densenet_paths, *resnet26_paths]

In [None]:
# Show the checkpoint paths selected for the ensemble.
ensemble_paths

['./saved/gluon_seresnext50_32x4d/seresnext50_seed7_1_best_acc_113_0.9231.pt',
 './saved/gluon_seresnext50_32x4d/seresnext50_2_best_acc_129_0.9077.pt',
 './saved/gluon_seresnext50_32x4d/seresnext50_seed7_2_best_acc_108_0.9308.pt',
 './saved/gluon_seresnext50_32x4d/seresnext50_4_best_acc_96_0.9538.pt',
 './saved/gluon_seresnext50_32x4d/seresnext50_1_best_acc_102_0.9462.pt',
 './saved/gluon_seresnext50_32x4d/seresnext50_seed7_4_best_acc_119_0.9231.pt',
 './saved/gluon_seresnext50_32x4d/seresnext50_3_best_acc_82_0.9308.pt',
 './saved/gluon_seresnext50_32x4d/seresnext50_0_best_acc_91_0.9538.pt',
 './saved/gluon_seresnext50_32x4d/seresnext50_seed7_0_best_acc_113_0.9308.pt',
 './saved/gluon_seresnext50_32x4d/seresnext50_seed7_3_best_acc_127_0.9308.pt',
 './saved/densenetblur121d/densenetblur121d_seed7_3_best_acc_74_0.9538.pt',
 './saved/densenetblur121d/densenetblur121d_4_best_acc_88_0.9615.pt',
 './saved/densenetblur121d/densenetblur121d_1_best_acc_82_0.9231.pt',
 './saved/densenetblur121d/

In [None]:
# Architecture-name strings passed to the model constructors when each
# checkpoint's network is rebuilt.
# SE-ResNeXt variants
resnet50 = 'gluon_seresnext50_32x4d'
resnet26 = 'seresnext26d_32x4d'
# Other backbones
densenet = 'densenetblur121d'
inception = 'gluon_inception_v3'

In [None]:
# Test-time data pipeline. The sample-submission CSV supplies the test rows, and
# `submission` is an independent copy so df_test itself is never modified.
df_test = pd.read_csv('./data/sample_submission.csv')
submission = df_test.copy()
test_transforms = get_valid_transforms()
test_dataset = MyDataset(df_test.values, test_transforms, color=None, root='./data/test')

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=0,
    shuffle=False,  # keep batches in df_test row order
    # Page-locked host memory only speeds up host-to-GPU copies, so request it
    # only when running on CUDA; on a CPU-only machine it just emits a warning.
    pin_memory=(device.type == "cuda"),
    drop_last=False,  # every test row must be evaluated
)

In [None]:
def get_stack_logits(model_paths, test_loader, df_test):
    """Sum the test-set logits produced by every checkpoint in `model_paths`.

    For each path the architecture is picked from a substring of the path, the
    checkpoint's ``'model'`` state dict is loaded into it, and the model is run
    over `test_loader` in eval mode with gradients disabled.

    Args:
        model_paths: Iterable of checkpoint file paths.
        test_loader: Re-iterable yielding ``(inputs, labels)`` batches in a
            fixed order; the labels are ignored. Any batch size works.
        df_test: Sized object with one entry per test sample; only
            ``len(df_test)`` is used.

    Returns:
        CPU tensor of shape ``(len(df_test), 2)`` holding the element-wise sum
        of the raw logits of all models.

    Raises:
        ValueError: If a path matches none of the known architecture substrings.
    """
    stack_logits = torch.zeros(len(df_test), 2).cpu()
    with torch.no_grad():
        for path in tqdm(model_paths):
            print(path)

            # The architecture is inferred from the checkpoint path.
            if 'densenetblur' in path:
                model = MyModel(2, densenet)
            elif 'resnext50' in path:
                model = MyModel(2, resnet50)
            elif 'inception' in path:
                model = MyModel(2, inception)
            elif 'resnext26d' in path:
                model = MyModel_drop(2, resnet26)
            else:
                # Without this branch an unknown path would either raise a
                # NameError or silently reuse the previous iteration's model.
                raise ValueError(f"cannot infer model architecture from checkpoint path: {path!r}")
            model = model.to(device)

            # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(path, map_location=device)
            model.load_state_dict(checkpoint['model'])
            model.eval()

            batch_logits = []
            for x, _ in test_loader:
                outs = model(x.to(device).float())
                batch_logits.append(outs.detach().cpu())
            # Concatenate along the batch axis so any batch size (including a
            # shorter final batch) produces a (num_samples, 2) tensor.
            res = torch.cat(batch_logits, dim=0)
            print(res.shape)
            stack_logits += res
    return stack_logits

In [None]:
# Run every checkpoint in ensemble_paths over the test loader and sum their logits.
stack_logits = get_stack_logits(ensemble_paths, test_loader, df_test)

  0% 0/25 [00:00<?, ?it/s]

./saved/gluon_seresnext50_32x4d/seresnext50_seed7_1_best_acc_113_0.9231.pt


  4% 1/25 [00:08<03:15,  8.13s/it]

torch.Size([100, 2])
./saved/gluon_seresnext50_32x4d/seresnext50_2_best_acc_129_0.9077.pt


  8% 2/25 [00:16<03:06,  8.12s/it]

torch.Size([100, 2])
./saved/gluon_seresnext50_32x4d/seresnext50_seed7_2_best_acc_108_0.9308.pt


 12% 3/25 [00:24<02:59,  8.18s/it]

torch.Size([100, 2])
./saved/gluon_seresnext50_32x4d/seresnext50_4_best_acc_96_0.9538.pt


 16% 4/25 [00:32<02:51,  8.18s/it]

torch.Size([100, 2])
./saved/gluon_seresnext50_32x4d/seresnext50_1_best_acc_102_0.9462.pt


 20% 5/25 [00:40<02:41,  8.09s/it]

torch.Size([100, 2])
./saved/gluon_seresnext50_32x4d/seresnext50_seed7_4_best_acc_119_0.9231.pt


 24% 6/25 [00:48<02:33,  8.07s/it]

torch.Size([100, 2])
./saved/gluon_seresnext50_32x4d/seresnext50_3_best_acc_82_0.9308.pt


 28% 7/25 [00:57<02:27,  8.19s/it]

torch.Size([100, 2])
./saved/gluon_seresnext50_32x4d/seresnext50_0_best_acc_91_0.9538.pt


 32% 8/25 [01:05<02:18,  8.17s/it]

torch.Size([100, 2])
./saved/gluon_seresnext50_32x4d/seresnext50_seed7_0_best_acc_113_0.9308.pt


 36% 9/25 [01:13<02:09,  8.11s/it]

torch.Size([100, 2])
./saved/gluon_seresnext50_32x4d/seresnext50_seed7_3_best_acc_127_0.9308.pt


 40% 10/25 [01:21<02:02,  8.17s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_seed7_3_best_acc_74_0.9538.pt


 44% 11/25 [01:29<01:52,  8.00s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_4_best_acc_88_0.9615.pt


 48% 12/25 [01:37<01:43,  7.97s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_1_best_acc_82_0.9231.pt


 52% 13/25 [01:45<01:35,  7.98s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_seed7_0_best_acc_57_0.9385.pt


 56% 14/25 [01:52<01:27,  7.97s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_3_best_acc_76_0.9462.pt


 60% 15/25 [02:01<01:21,  8.12s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_seed7_4_best_acc_72_0.9538.pt


 64% 16/25 [02:09<01:13,  8.12s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_seed7_1_best_acc_75_0.9154.pt


 68% 17/25 [02:18<01:05,  8.25s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_seed7_2_best_acc_47_0.9538.pt


 72% 18/25 [02:26<00:57,  8.24s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_0_best_acc_70_0.9538.pt


 76% 19/25 [02:34<00:48,  8.11s/it]

torch.Size([100, 2])
./saved/densenetblur121d/densenetblur121d_2_best_acc_96_0.9077.pt


 80% 20/25 [02:42<00:41,  8.26s/it]

torch.Size([100, 2])
./saved/seresnext26d_32x4d/4_0.9846_seresnext26d_32x4d_2Aug.pth


 84% 21/25 [02:50<00:32,  8.18s/it]

torch.Size([100, 2])
./saved/seresnext26d_32x4d/0_0.9692_seresnext26d_32x4d_2Aug.pth


 88% 22/25 [02:58<00:24,  8.01s/it]

torch.Size([100, 2])
./saved/seresnext26d_32x4d/3_0.9538_seresnext26d_32x4d_2Aug.pth


 92% 23/25 [03:06<00:15,  7.91s/it]

torch.Size([100, 2])
./saved/seresnext26d_32x4d/1_0.9385_seresnext26d_32x4d_2Aug.pth


 96% 24/25 [03:14<00:07,  7.97s/it]

torch.Size([100, 2])
./saved/seresnext26d_32x4d/2_0.9154_seresnext26d_32x4d_2Aug.pth


100% 25/25 [03:22<00:00,  8.08s/it]

torch.Size([100, 2])





In [None]:
def make_submission(logits, submission, file_name):
    """Write argmax predictions to a submission CSV and summarise them.

    Args:
        logits: Tensor whose last dimension indexes the classes, with one row
            per row of `submission`.
        submission: DataFrame to fill; its 'COVID' column is set in place.
        file_name: Base name; the file is written to
            ``./submissions/_{file_name}.csv``.

    Returns:
        Tuple ``(ratio, check)``. `check` is the CSV read back from disk and
        `ratio` is the sum of its 'COVID' column divided by its row count
        (the positive fraction when predictions are 0/1).
    """
    import os  # local import: only this function needs it

    pred = logits.argmax(-1).cpu().numpy()
    submission['COVID'] = pred

    out_path = f'./submissions/_{file_name}.csv'
    # to_csv does not create missing folders, so make sure the target exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    submission.to_csv(out_path, index=False)

    # Read the file back so the summary reflects what was actually written.
    check = pd.read_csv(out_path)
    ratio = check['COVID'].sum() / len(check)
    return ratio, check

In [None]:
# Write the ensemble predictions to ./submissions/_ensemble_submission.csv; the cell
# displays the returned (COVID ratio, re-read submission frame) pair.
save_file_name = 'ensemble_submission'
make_submission(stack_logits, submission, save_file_name)

(0.48,
    file_name  COVID
 0      0.png      0
 1      1.png      0
 2      2.png      0
 3      3.png      1
 4      4.png      0
 5      5.png      1
 6      6.png      1
 7      7.png      1
 8      8.png      0
 9      9.png      1
 10    10.png      0
 11    11.png      0
 12    12.png      0
 13    13.png      0
 14    14.png      1
 15    15.png      0
 16    16.png      1
 17    17.png      0
 18    18.png      0
 19    19.png      0
 20    20.png      0
 21    21.png      0
 22    22.png      0
 23    23.png      0
 24    24.png      1
 25    25.png      0
 26    26.png      0
 27    27.png      0
 28    28.png      1
 29    29.png      1
 30    30.png      1
 31    31.png      1
 32    32.png      1
 33    33.png      0
 34    34.png      0
 35    35.png      0
 36    36.png      0
 37    37.png      1
 38    38.png      1
 39    39.png      0
 40    40.png      1
 41    41.png      1
 42    42.png      0
 43    43.png      1
 44    44.png      1
 45    45.png      1
 46   