# Model ensemble inference

## Environment
### System 1.
i5-13600KF
24 GB Memory
RTX 3060 with 12GB VRAM 

### System 2.
i7-8700K
16 GB Memory
RTX 3090 with 24GB VRAM

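## Models
**3x** vit_base_resnet50d_224 (three separate training runs, 5 fold checkpoints each)
**1x** xception41p (5 fold checkpoints)

In total, 20 models are combined by soft voting (summed softmax probabilities).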


In [8]:
from tqdm import tqdm

import pandas as pd


from torch.utils.data import DataLoader
from torchvision.transforms import AutoAugment
from torchvision.transforms import AutoAugmentPolicy

import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm

from utils.utils import *
from utils.MyModel import *

## Define functions

In [11]:
def getTransform(train_mean=(.5, .5, .5), train_std=(.5, .5, .5), val_mean=(.5, .5, .5), val_std=(.5, .5, .5), aug_mode='albu'):
    """Build the (train, test) image transform pipelines.

    Both pipelines resize to 32x32 and normalise with the given per-channel
    statistics. The training pipeline additionally applies random augmentation.

    Args:
        train_mean, train_std: per-channel mean/std used by the training pipeline.
        val_mean, val_std: per-channel mean/std used by the test pipeline.
        aug_mode: 'albu' builds albumentations pipelines ending in ToTensorV2.
            Any other value builds torchvision pipelines, with CIFAR-10
            AutoAugment on the training side.

    Returns:
        A ``(transform_train, transform_test)`` tuple.
    """
    if aug_mode == 'albu':
        transform_train = A.Compose([
            A.Resize(32, 32),
            # p=1.0 always applies the rotation (replaces deprecated always_apply=True).
            A.Rotate(limit=(-360, 360), interpolation=1, border_mode=1, p=1.0),
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.Normalize(train_mean, train_std),
            ToTensorV2(),
        ])

        transform_test = A.Compose([
            A.Resize(32, 32),
            A.Normalize(val_mean, val_std),
            ToTensorV2(),
        ])
    else:
        # `transforms` is not imported explicitly at the top of the file.
        from torchvision import transforms

        # Resize to an explicit (32, 32) so the output matches the albu branch.
        # Resize(32) would only scale the shorter side.
        # ToTensor must run before Normalize: Normalize operates on tensors,
        # while AutoAugment runs on the PIL image beforehand.
        transform_train = transforms.Compose([
            transforms.Resize((32, 32)),
            AutoAugment(AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            transforms.Normalize(train_mean, train_std),
        ])

        transform_test = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(val_mean, val_std),
        ])

    return transform_train, transform_test

## Load Data

In [12]:
# Seed every RNG before loading the data, so anything random inside the
# loaders is covered as well (previously the seed was set afterwards).
set_random_seed(1813)
training_set, test_set = getDataSet('../dataset/')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [14]:
# Per-channel normalisation statistics are computed on the *training* split.
# They are passed to getTransform as both the train and test mean/std, so test
# images are normalised with the same statistics as the training images.
(test_norm, test_std) = getNormStd(training_set)

In [15]:
# Only the evaluation pipeline is needed for inference; the training one is discarded.
_, test_transform = getTransform(
    train_mean=test_norm,
    train_std=test_std,
    val_mean=test_norm,
    val_std=test_std,
    aug_mode='albu',
)

## Inference

Inference was run on System 1. It uses about 10 GB of VRAM with `batch_size=16`, because all 20 models are kept on the GPU at the same time.

In [16]:
NUM_CLASSES = 30  # output classes of every ensemble member
N_FOLDS = 5       # one checkpoint per cross-validation fold in each weights directory

# One entry per training run: (timm architecture, checkpoint directory,
# extra create_model kwargs). The ViT-hybrid models are created with
# img_size=32 to match the 32x32 images produced by the test transform.
ENSEMBLE_SPECS = [
    ('vit_base_resnet50d_224', '../weights/vit_base_resnet50d_224_cutMix_2', {'img_size': 32}),
    ('vit_base_resnet50d_224', '../weights/vit_base_resnet50d_224_cutMix_0.5838', {'img_size': 32}),
    ('vit_base_resnet50d_224', '../weights/vit_base_resnet50_cos', {'img_size': 32}),
    ('xception41p', '../weights/xception41p', {}),
]


def _load_fold_models(model_name, weights_dir, n_folds=N_FOLDS, **create_kwargs):
    """Create one model per fold, load best_model_fold{i}.pth, move to `device`, set eval mode."""
    fold_models = []
    for fold in range(n_folds):
        model = timm.create_model(model_name=model_name, pretrained=False,
                                  num_classes=NUM_CLASSES, **create_kwargs)
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(f'{weights_dir}/best_model_fold{fold}.pth', map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        fold_models.append(model)
    return fold_models


models = []

test_dataset = CustomLoader(dataset_list=test_set, transforms=test_transform, is_train=False, aug_mode='albu')
test_dataloader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False, num_workers=4)

for model_name, weights_dir, create_kwargs in ENSEMBLE_SPECS:
    models.extend(_load_fold_models(model_name, weights_dir, **create_kwargs))

id_list = []
pred_list = []

with torch.no_grad():
    for image, file_name in tqdm(test_dataloader):
        image = image.to(device)

        # Soft voting: sum every member's class probabilities. The argmax of
        # the sum equals the argmax of the mean, so no division is needed.
        prob_sum = torch.zeros((len(image), NUM_CLASSES), device=device)
        for model in models:
            prob_sum += model(image).softmax(dim=1)

        pred_list += prob_sum.argmax(dim=1).cpu().tolist()
        id_list += file_name

100%|██████████| 188/188 [00:32<00:00,  5.87it/s]


## Save result to .csv file

In [17]:
# Write one row per test image: file id and predicted class index.
column_name = ['id', 'label']
data = {'id': id_list, 'label': pred_list}
# columns= pins the output column order to column_name (previously unused).
data_df = pd.DataFrame(data, columns=column_name)

data_df.to_csv('./final_result_replay.csv', index=False)