# import

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import math
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from  PIL import Image
import os
import random
from tqdm import tqdm
import timm
import sys
sys.path.insert(0, '../input/tiny-vit-model')
import tiny_vit

# 定义参数

In [2]:
INPUT_PATH = '../input/1103-tinyvit/'
TRAIN_CSV_PATH = '../input/cassava-leaf-disease-classification/train.csv'
TEST_IMAGE_PATH = '../input/cassava-leaf-disease-classification/test_images/'
SUBMISSION_PATH = 'submission.csv'
TINY_VIT_PATH = 'tinyvit.pth'
DEVICES = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())]
OUT_CLASSES = 5
IMAGE_SIZE = 512
OPTIMIZER = torch.optim.AdamW
SEED = 42
TTA = 3

# tr5.seed

In [7]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
seed_everything(SEED)

# 定义模型

In [8]:
my_model_1 = tiny_vit.tiny_vit_21m_512(pretrained=False)
my_model_1.head = nn.Linear(my_model_1.head.in_features, OUT_CLASSES)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act2): ReLU(inplace=True)
      (aa): Identity()
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_

In [10]:
torch.cuda.empty_cache()

# 验证

In [11]:
test_augs = A.Compose([
    A.OneOf([
        A.Resize(IMAGE_SIZE, IMAGE_SIZE, p=1.0),
        A.CenterCrop(IMAGE_SIZE, IMAGE_SIZE, p=1.0),
        A.RandomResizedCrop(IMAGE_SIZE, IMAGE_SIZE, p=1.0)
    ], p=1.0),
    A.Transpose(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0)
    ], p=1.0
)


preds_1 = []
model_param = torch.load(os.path.join(INPUT_PATH, TINY_VIT_PATH))
new_model_param = {k[7:]: v for k, v in model_param.items() if 'module.' in k}
my_model_1.load_state_dict(new_model_param)
my_model_1 = my_model_1.to(DEVICES[0])
my_model_1.eval()
test_image_list = np.asarray([image_name for image_name in os.listdir(TEST_IMAGE_PATH)])
for single_image_name in test_image_list:
    with torch.no_grad():
        ans = torch.zeros(5).cuda()
        for _ in range(TTA):
            image = Image.open(os.path.join(TEST_IMAGE_PATH, single_image_name))
            aug_image = test_augs(image=np.array(image))['image']
            test_image = torch.tensor(aug_image, dtype=torch.float).unsqueeze(0).cuda()
            ans += my_model_1(test_image).view(ans.shape)
        ans /= TTA
        preds_1.append(ans)
predictions_1 = torch.stack(preds_1, dim=0).to('cpu')
normalize_pred_1 = F.normalize(predictions_1.T, p=2, dim=0).T

label = normalize_pred_1.argmax(dim=-1).numpy()
label_list = list(label)
df_submission = pd.DataFrame(columns=pd.read_csv(TRAIN_CSV_PATH).columns)
df_submission['image_id'] = pd.DataFrame(test_image_list)
df_submission['label'] = pd.DataFrame(label_list)
df_submission.to_csv(SUBMISSION_PATH, index=False)

  test_image = torch.tensor(aug_image, dtype=torch.float).unsqueeze(0).cuda()
  test_image = torch.tensor(aug_image, dtype=torch.float).unsqueeze(0).cuda()
