## Import Libs

In [None]:
import os
import warnings # 避免一些可以忽略的报错
warnings.filterwarnings('ignore')
import random
import gc
import copy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm # 进度条
import time

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import timm

import albumentations as A # 数据增强库
from albumentations.pytorch import ToTensorV2

## CONFIG

In [None]:
# Debug switch; not referenced anywhere in the visible part of this notebook.
is_debug = False

class CONFIG:
    """Central settings for the inference notebook, read as class attributes."""

    # Seed handed to set_seed() for reproducibility.
    seed = 308

    # Batch size of the test DataLoader.
    test_batch_size = 512
    # Height and width each flattened pixel row is reshaped to.
    img_size = [28, 28]
    
    # Output width of the classification head.
    n_classes = 10

    # Number of DataLoader worker processes.
    n_workers = 1
    
    # Use the first CUDA device when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # Architecture name passed to timm.create_model.
    model_name = "tf_efficientnetv2_s.in21k_ft_in1k"
    # "GeMPool" makes the model swap the backbone's global pooling for GeMPool.
    pool_name = "GeMPool"
    timm_pretrained = False # Kaggle submission notebooks run without internet access, so pretrained weights cannot be downloaded; keep this False (True fails with a network error)
    
    # Input CSV with the test images and the checkpoint holding the trained weights.
    test_csv = "/kaggle/input/digit-recognizer/test.csv"
    ckpt_path = "/kaggle/input/308-digit-recognizer-baseline2-swa/swa.pth"

## Set Random Seed

In [None]:
def set_seed(seed=308):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED`` and asks cuDNN for deterministic algorithms.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The seeding calls are independent of each other, so apply them in a loop.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    
set_seed(CONFIG.seed) # Fix the random seed so results can be reproduced

## Data Processing

In [None]:
test = pd.read_csv(CONFIG.test_csv) # Load the test-set data
# Bare expression: shows the DataFrame as the notebook cell output.
test

## Dataset and DataLoader

In [None]:
def transform(img):
    """Normalize an image array and convert it to a tensor.

    Args:
        img: Image array handed to the albumentations pipeline as ``image``.

    Returns:
        The ``"image"`` entry of the pipeline output.
    """
    steps = [
        A.Normalize(),  # keep identical to the validation transform used in training
        ToTensorV2(),
    ]
    pipeline = A.Compose(steps)
    augmented = pipeline(image=img)
    return augmented["image"]

In [None]:
class MyDataset(Dataset):
    """Dataset that turns each DataFrame row into a 3-channel image.

    The pixel values are read from the ``"pixel0"`` column through the last
    column of the row, reshaped to ``CONFIG.img_size`` and replicated three
    times along the channel axis.

    Args:
        df: DataFrame with one flattened image per row.
        transform: Optional callable applied to the image array
            (e.g. normalization / tensor conversion / augmentation).
    """

    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, id)`` for the row at position ``idx``.

        The id is ``str(idx)``: the test set has no labels, so the row
        position stands in for one.
        """
        row = self.df.iloc[idx, :]  # positional lookup of row ``idx``
        img = row["pixel0":].values
        # (H, W, 1): the trailing -1 absorbs whatever is left after H * W.
        img = img.reshape(CONFIG.img_size[0], CONFIG.img_size[1], -1)
        # Repeat the channel axis three times to get a 3-channel image.
        img = np.concatenate([img] * 3, axis=-1)

        # Identity check is the correct test for the None sentinel.
        if self.transform is not None:
            img = self.transform(img)

        return img, str(idx)

In [None]:
def prepare_loaders():
    """Build the DataLoader used for inference over the global ``test`` frame.

    Returns:
        A DataLoader that walks the test set in its original row order.
    """
    loader_kwargs = dict(
        batch_size=CONFIG.test_batch_size,
        num_workers=CONFIG.n_workers,
        # Inference processes the data in order, so shuffling stays off.
        shuffle=False,
        pin_memory=True,
    )
    return DataLoader(MyDataset(df=test, transform=transform), **loader_kwargs)

## Model

In [None]:
class GeMPool(nn.Module):
    """Generalized-mean (GeM) pooling over the last two dimensions.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` across dims
    ``(-2, -1)``. The exponent ``p`` is a learnable one-element parameter.

    Args:
        p: Initial value of the pooling exponent.
        eps: Lower clamp bound applied before exponentiation.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Kept under the attribute name ``p`` so state-dict keys are unchanged.
        self.p = nn.Parameter(p * torch.ones(1))
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` with the module's own exponent and epsilon."""
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Apply generalized-mean pooling to ``x`` over its last two dims."""
        powered = x.clamp(min=eps).pow(p)
        averaged = powered.mean(dim=(-2, -1))
        return averaged.pow(1. / p)

    def __repr__(self):
        current_p = self.p.data.tolist()[0]
        return f"{type(self).__name__}(p={current_p:.4f}, eps={self.eps})"

In [None]:
class DigitRecognizerModel(nn.Module):
    """timm backbone followed by a custom linear classification head.

    The backbone's own classifier is replaced by ``nn.Identity`` and, when
    ``CONFIG.pool_name`` is ``"GeMPool"``, its global pooling is replaced by
    ``GeMPool``.
    """

    def __init__(self):
        super().__init__()

        # ``pretrained`` is False in the inference notebook (see CONFIG).
        backbone = timm.create_model(
            model_name=CONFIG.model_name,
            pretrained=CONFIG.timm_pretrained,
        )

        # Swap the final global pooling layer for GeM pooling when requested.
        if CONFIG.pool_name == "GeMPool":
            backbone.global_pool = GeMPool()

        # Read the classifier's input width before discarding the classifier.
        in_features = backbone.classifier.in_features
        backbone.classifier = nn.Identity()
        self.backbone = backbone

        # Custom head that takes the place of the original classifier layer.
        self.head = nn.Sequential(nn.Linear(in_features, CONFIG.n_classes))

    def forward(self, x):
        """Return the head's output for the backbone features of ``x``."""
        features = self.backbone(x)
        return self.head(features)

## Load Model

In [None]:
model = DigitRecognizerModel() # Instantiate the model
model.to(CONFIG.device)

# Load the weights obtained from training
model.load_state_dict(torch.load(CONFIG.ckpt_path, map_location=CONFIG.device)) # map_location remaps the stored tensors onto CONFIG.device so a checkpoint saved on a different device still loads

## Infer Function

In [None]:
def Infer(model, test_loader):
    """Run ``model`` over ``test_loader`` and return the predicted classes.

    The model is put in eval mode and run with gradients disabled.

    Args:
        model: Network whose output has the class scores on dim 1
            (``(batch, n_classes)``).
        test_loader: Sized iterable yielding ``(img, img_id)`` batches; the
            ids are not used here.

    Returns:
        NumPy array with the argmax class index of every sample, in loader
        order. An empty ``int64`` array is returned when the loader yields
        no batches.
    """
    model.eval()

    batch_preds = []
    with torch.no_grad():
        for img, _ in tqdm(test_loader, total=len(test_loader)):
            img = img.to(CONFIG.device, dtype=torch.float)

            # Explicit class dim: calling softmax without ``dim`` is deprecated.
            probs = F.softmax(model(img), dim=1)
            batch_preds.append(probs.argmax(1).detach().cpu().numpy())

    # np.concatenate raises on an empty list, so handle an empty loader here.
    if not batch_preds:
        return np.empty(0, dtype=np.int64)
    return np.concatenate(batch_preds)

## Start Infer

In [None]:
# Build the DataLoader over the test set.
test_loader = prepare_loaders()

In [None]:
# Predicted class index for every test sample, in loader order.
y_preds = Infer(model, test_loader)

## Make Submission

In [None]:
# Assemble the submission table: one row per test image.
sub = pd.DataFrame()
# Ids are the test frame's index shifted by one.
sub["ImageId"] = test.index + 1
sub["Label"] = y_preds
# Bare expression: shows the table as the notebook cell output.
sub

In [None]:
# Write the submission without the DataFrame index column.
sub.to_csv('submission.csv', index=False)
# Read the file back so the cell output shows what was actually written.
pd.read_csv('submission.csv')