## Setup

In [None]:
import import_ipynb
import os, timm, cv2 
import numpy as np
from glob import glob
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score
from itertools import combinations
from model_load import *
from dataloader import *

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
# is_available must be *called*: the bare function object is always truthy.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Validation

In [None]:
class CustomDataset(Dataset):
    """Image-folder dataset for the validation split.

    Files are expected at ``<root>/<class_name>/<image_file>``. Class names
    are sorted alphabetically and mapped to integer labels ``0..K-1``.

    Args:
        transforms: optional callable applied to each RGB image (the numpy
            array produced by OpenCV) before it is returned.
        root: directory holding one sub-directory per class. Defaults to
            ``./valid_image``.

    Attributes:
        path: sorted list of image file paths.
        keys: sorted class names; ``keys[i]`` is the name of label ``i``.
        label: integer label for each entry of ``path``.
    """

    def __init__(self, transforms=None, root='./valid_image'):
        # glob returns files in arbitrary order; sort for a reproducible
        # sample (and therefore prediction) order.
        self.path = sorted(glob(os.path.join(root, '*', '*')))
        # The class name is each file's parent directory. os.path works for
        # any root depth and path separator, unlike a fixed split('/') index.
        class_names = [os.path.basename(os.path.dirname(p)) for p in self.path]
        # Kept on the instance so callers can label plots consistently.
        self.keys = sorted(set(class_names))
        dictkeys = {key: idx for idx, key in enumerate(self.keys)}
        self.label = [dictkeys[name] for name in class_names]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        path = self.path[index]
        # imread takes an IMREAD_* flag; a colour-conversion code such as
        # COLOR_BGR2RGB must go through cvtColor. Read 3-channel BGR first.
        data = cv2.imread(path, cv2.IMREAD_COLOR)
        if data is None:  # imread signals failure by returning None
            raise FileNotFoundError(f'could not read image: {path}')
        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)

        if self.transforms is not None:
            data = self.transforms(data)

        return data, self.label[index]

    def __len__(self):
        return len(self.path)

In [None]:
# Validation preprocessing: HWC uint8 image -> CHW float tensor in [0, 1],
# resized to size x size, then shifted/scaled per channel to [-1, 1].
size = 224
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize([size, size]),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# NOTE(review): these CIFAR-10 loaders come from the project `dataloader`
# module and are not used anywhere below in this notebook — confirm whether
# this (possibly expensive) call is still needed.
cifar10_train, cifar10_valid = cifar10(224, 32, workers = 4)

# Validation images from ./valid_image, iterated without shuffling so every
# model sees the samples in the same order.
test_dataset = CustomDataset(transforms=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Single-model baseline: timm ResNet-50 with a 10-way head, weights from
# ./baseline.pth.
model = timm.create_model('resnet50', pretrained=False, num_classes=10)
# map_location='cpu' lets a checkpoint saved on a GPU load on CPU-only
# machines too; the model is moved to `device` before inference.
checkpoint = torch.load(os.path.join('.', 'baseline.pth'), map_location='cpu')
model.load_state_dict(checkpoint)

In [None]:
# Baseline inference: predicted class index and true label for every
# validation image, in loader order.
model.eval()
model.to(device)
batch_preds, batch_labels = [], []

with torch.no_grad():
    for img, ll in test_loader:
        logits = model(img.to(device))
        batch_preds.append(logits.argmax(dim=1).cpu())
        batch_labels.append(torch.as_tensor(ll))

# Concatenate once at the end instead of re-allocating a growing tensor per
# batch; predictions and labels stay integer-typed.
model_pred = torch.cat(batch_preds) if batch_preds else torch.empty(0, dtype=torch.long)
label = torch.cat(batch_labels) if batch_labels else torch.empty(0, dtype=torch.long)

In [None]:
# Baseline confusion matrix (rows: true class, columns: predicted class).
# Class names are the sub-directory names of the validation images, sorted
# the same way CustomDataset sorts them before assigning integer labels.
keys = sorted({os.path.basename(os.path.dirname(p)) for p in test_dataset.path})
# Pin the label set so the matrix is K x K even if a class never appears.
arr = confusion_matrix(label, model_pred, labels=list(range(len(keys))))
cf_pd = pd.DataFrame(arr, index=keys, columns=keys)

plt.figure(figsize=(10, 10))
sns.heatmap(cf_pd, annot=True, fmt='d', cmap='rocket_r', xticklabels=keys, yticklabels=keys)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')

## Ensemble

In [None]:
def test(model, test_loader, name, device=None):
    """Run ``model`` over ``test_loader`` and return softmax probabilities.

    Args:
        model: classifier returning logits of shape (batch, num_outputs).
        test_loader: iterable of ``(images, labels)`` batches.
        name: model name; for ``'seresnext101_32x4d'`` only the first 10
            output columns are kept before the softmax.
        device: device to run inference on. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        ``(probs, labels)``: probabilities of shape (N, C) and the matching
        labels, both on the CPU. Both are empty if the loader is empty.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    model.to(device)
    probs, labels = [], []

    with torch.no_grad():
        for img, ll in test_loader:
            logits = model(img.to(device))
            if name == 'seresnext101_32x4d':
                logits = logits[:, :10]
            # Explicit dim: softmax with an implicit dim is deprecated.
            probs.append(F.softmax(logits, dim=1).cpu())
            labels.append(torch.as_tensor(ll).cpu())

    if not probs:
        return torch.empty(0), torch.empty(0, dtype=torch.long)
    # Concatenate once rather than re-allocating on every batch.
    return torch.cat(probs), torch.cat(labels)


# The helper's name starts with "test"; keep pytest from collecting it.
test.__test__ = False

In [None]:
# Ensemble members. Each entry is both the architecture name passed to
# model_load() and the stem of its checkpoint file (<name>.pth).
model_name = [
    'seresnext101_32x4d',
    'efficientnet_lite2',
    'res2net50_26w_8s',
    'efficientnetv2_rw_s',
    'efficientnet_b0',
]

In [None]:
# Soft-voting ensemble: load every model in `model_name` from ./submit,
# collect its softmax probabilities on the validation set, and sum them.
model_pred = []
label = None
for i in tqdm(model_name):
    check = torch.load(os.path.join('./submit', i + '.pth'), map_location='cpu')
    if i in ['resnet50', 'efficientnet_b0', 'efficientnetv2_rw_s']:
        # These checkpoints wrap the weights under 'model_state_dict'.
        model = model_load(i, pre=False, num=10)
        check = check['model_state_dict']
    elif i == 'seresnext101_32x4d':  # exact match, not a substring test
        model = model_load(i, pre=False, num=192)
        model.relu = nn.ReLU(True)
        model.fc2 = nn.Linear(192, 10)
    else:
        model = model_load(i, pre=False, num=10)
    model.load_state_dict(check)

    # Run inference once and keep both outputs.
    probs, targets = test(model, test_loader, i)
    model_pred.append(probs)
    if label is None:
        # test_loader does not shuffle, so labels are identical for all models.
        label = targets

pred = torch.stack(model_pred).sum(dim=0)
final_pred = pred.argmax(dim=1)

# scikit-learn expects (y_true, y_pred): rows are true classes.
print(confusion_matrix(label, final_pred))
print(accuracy_score(label, final_pred))

In [None]:
# Ensemble confusion matrix (rows: true class, columns: predicted class),
# built from the ensemble's argmax predictions `final_pred`.
# Class names are the sub-directory names of the validation images, sorted
# the same way CustomDataset sorts them before assigning integer labels.
keys = sorted({os.path.basename(os.path.dirname(p)) for p in test_dataset.path})
arr = confusion_matrix(label, final_pred, labels=list(range(len(keys))))
cf_pd = pd.DataFrame(arr, index=keys, columns=keys)

plt.figure(figsize=(10, 10))
sns.heatmap(cf_pd, annot=True, fmt='d', cmap='rocket_r', xticklabels=keys, yticklabels=keys)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Ensemble Confusion Matrix')

## Find the best model combinations

In [None]:
import math

# Score every non-empty subset of `model_name` as a soft-voting ensemble.
# Each model is loaded and run once; each combination then only sums the
# cached probabilities instead of reloading and re-running its members.
# NOTE(review): checkpoints are read from '.' here, while the ensemble cell
# reads them from './submit' — confirm which directory is intended.
ckpt_dir = '.'
cached_probs = {}
label = None
for i in tqdm(model_name):
    check = torch.load(os.path.join(ckpt_dir, i + '.pth'), map_location='cpu')
    if i in ['resnet50', 'efficientnet_b0', 'efficientnetv2_rw_s']:
        # These checkpoints wrap the weights under 'model_state_dict'.
        model = model_load(i, pre=False, num=10)
        check = check['model_state_dict']
    elif i == 'seresnext101_32x4d':  # exact match, not a substring test
        model = model_load(i, pre=False, num=192)
        model.relu = nn.ReLU(True)
        model.fc2 = nn.Linear(192, 10)
    else:
        model = model_load(i, pre=False, num=10)
    model.load_state_dict(check)

    probs, targets = test(model, test_loader, i)
    cached_probs[i] = probs
    if label is None:
        # test_loader does not shuffle, so labels are identical for all models.
        label = targets

n_models = len(model_name)
# acc[j][idx] = accuracy of the idx-th combination of j+1 models, in
# itertools.combinations order. The widest row has C(n, n//2) entries.
# Unused cells are -1 so they can never rank as a "best" combination.
acc = torch.full((n_models, math.comb(n_models, n_models // 2)), -1.0)

for j in range(n_models):
    comb = list(combinations(model_name, j + 1))
    for idx, k in enumerate(comb):
        pred = torch.stack([cached_probs[m] for m in k]).sum(dim=0)
        final_pred = pred.argmax(dim=1)

        score = accuracy_score(label, final_pred)
        print(k)
        # scikit-learn expects (y_true, y_pred): rows are true classes.
        print(confusion_matrix(label, final_pred))
        print(score)
        acc[j][idx] = score

In [None]:
# Display the combination(s) that use all five ensemble members.
[*combinations(model_name, 5)]

## How often each model appears in the top-10 combinations

In [None]:
# Count how often each model appears among the 10 most accurate combinations.
# A flat index into `acc` decodes to (row, col) =
# (combination size - 1, position within combinations(model_name, size)).
dic = {name: 0 for name in model_name}
n_top = min(10, acc.numel())
for flat_idx in acc.flatten().topk(n_top).indices.tolist():
    a, b = divmod(flat_idx, acc.size(1))
    # Rebuild the combinations for this row; `comb` from the search loop only
    # holds the last (largest) size.
    size_combos = list(combinations(model_name, a + 1))
    if b >= len(size_combos):
        continue  # padding cell in `acc`, not a real combination
    for m in size_combos[b]:
        dic[m] += 1
dic