## 0. Data downloading

In [3]:
import json
import os

# Kaggle API credentials. Read them from the environment (KAGGLE_USERNAME /
# KAGGLE_KEY) so the secret key is never stored in the notebook itself.
api_token = {
    "username": os.environ.get("KAGGLE_USERNAME", "dd13969"),
    "key": os.environ.get("KAGGLE_KEY", ""),
}
kaggle_json = '/root/.kaggle/kaggle.json'
# The directory may not exist on a fresh machine.
os.makedirs(os.path.dirname(kaggle_json), exist_ok=True)
with open(kaggle_json, 'w') as file:
    json.dump(api_token, file)
# Restrict the credentials file to the owner; the kaggle CLI warns otherwise.
os.chmod(kaggle_json, 0o600)

In [4]:
# Download the competition archive (nzmsa-2024.zip) into the working directory.
!kaggle competitions download -c nzmsa-2024

Downloading nzmsa-2024.zip to /kaggle/working
 87%|██████████████████████████████████▊     | 113M/130M [00:01<00:00, 76.3MB/s]
100%|████████████████████████████████████████| 130M/130M [00:01<00:00, 75.2MB/s]


In [5]:
import os
import zipfile

# unzip dataset
def unzipDataset(data_dir, extract_path=None):
    """Extract the archive ``<data_dir>.zip``.

    Args:
        data_dir (str | os.PathLike): Path of the archive without its ``.zip``
            suffix, e.g. ``'nzmsa-2024'`` for ``nzmsa-2024.zip``.
        extract_path (str | os.PathLike, optional): Destination directory.
            Defaults to the current working directory.
    """
    # f-string (rather than ``+``) also accepts pathlib.Path inputs.
    zip_path = f'{data_dir}.zip'
    if extract_path is None:
        extract_path = os.getcwd()

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)
# Archive name (without .zip) fetched by the kaggle download command above.
file_name = 'nzmsa-2024'
unzipDataset(file_name)

## 1. Data loading & preprocessing

We tried the following data preprocessing and augmentation steps:
- RandomResizedCrop: helpful
- RandomHorizontalFlip: helpful
- RandomRotation: hurt performance
- RandomVerticalFlip: hurt performance
- Normalize: no benefit

In [11]:
import csv
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split

# data preprocessing and augmentation
# Training-time augmentation: random crop-and-resize back to 32x32 plus a
# horizontal flip. Rotation, vertical flip and ImageNet normalization were
# also tried and left out (see the notes above this cell).
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop((32, 32)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)

# Validation and test images are only converted to tensors (no augmentation).
val_transforms = transforms.Compose([transforms.ToTensor()])
test_transforms = transforms.Compose([transforms.ToTensor()])


# define dataset
class CIFAR10Dataset(Dataset):
    """CIFAR10 dataset backed by image files plus a CSV label file.

    Args:
        data_list (list[str]): The image file paths of the CIFAR10 Dataset.
            Each file name must match an ``image_<id>.png`` key built from
            the label CSV.
        label_path (str): Path of the CSV label file; it must have ``id`` and
            ``label`` columns.
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version. Defaults to
            ``transforms.ToTensor()``.
    """
    def __init__(self, data_list, label_path, transform=None):
        self.data_list = data_list
        self.label_dict = self._csv2dict(label_path)
        self.transform = transform
        if self.transform is None:
            self.transform = transforms.ToTensor()

    def _csv2dict(self, label_path):
        """Load labels from the CSV file as ``{'image_<id>.png': int(label)}``."""
        label_dict = {}
        # newline='' is what the csv module expects for files it reads.
        with open(label_path, mode='r', encoding='utf-8', newline='') as csv_file:
            for row in csv.DictReader(csv_file):
                label_dict[f'image_{row["id"]}.png'] = int(row['label'])
        return label_dict

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the ``idx``-th file."""
        img_path = self.data_list[idx]
        # Close the file handle deterministically instead of relying on
        # garbage collection of the lazily-opened PIL image.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        img_transformed = self.transform(img)
        label = self.label_dict[os.path.basename(img_path)]
        return img_transformed, label

class TestDataset(CIFAR10Dataset):
    """CIFAR10 test dataset (no labels).

    Items are ``(transformed_image, image_id)``. ``image_id`` is the string
    after the last ``_`` in the file name without its extension, e.g. ``'123'``
    for ``image_123.png``.

    Args:
        data_list (list[str]): The image file paths of the CIFAR10 test set.
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version. Defaults to
            ``transforms.ToTensor()``.
    """
    def __init__(self, data_list, transform=None):
        # The parent constructor requires a label CSV, which the test split
        # does not have, so it is deliberately not called.
        self.data_list = data_list
        self.transform = transform
        if self.transform is None:
            self.transform = transforms.ToTensor()

    def __getitem__(self, idx):
        img_path = self.data_list[idx]
        # Close the file handle deterministically.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        img_transformed = self.transform(img)
        # Parse from the file stem so directory names containing '_' and
        # extensions of any length are handled.
        stem = os.path.splitext(os.path.basename(img_path))[0]
        image_id = stem.rsplit('_', 1)[-1]
        return img_transformed, image_id


# dataset root
data_root = 'cifar10_images/train'
test_data_root = 'cifar10_images/test'
label_path = 'train.csv'

# read data path lists. os.listdir returns entries in an arbitrary,
# filesystem-dependent order; sorting makes the seeded train/val split below
# reproducible across machines and keeps the submission row order stable.
data_list = ['/'.join([data_root, i]) for i in sorted(os.listdir(data_root))]
test_list = ['/'.join([test_data_root, i]) for i in sorted(os.listdir(test_data_root))]
train_list, val_list = train_test_split(data_list, test_size=0.2, random_state=101)
print(f'train dataset size: {len(train_list)}, validation dataset size: {len(val_list)}, test datasetsize: {len(test_list)}')

# build train, validation and test dataset
train_dataset = CIFAR10Dataset(train_list, label_path, train_transforms)
val_dataset = CIFAR10Dataset(val_list, label_path, val_transforms)
test_dataset = TestDataset(test_list, test_transforms)

train dataset size: 40000, validation dataset size: 10000, test datasetsize: 5000


## 2. Define the model

### Resnet
- Classic CNN network.
- Very easy to achieve 90%+ score on this dataset.

In [12]:
import torch.nn as nn

class BasicBlock(nn.Module):
    """Basic residual block for ResNet18 and ResNet34.

    Two 3x3 conv-BN layers (ReLU in between) added to a shortcut. The shortcut
    is the identity, or a 1x1 conv-BN projection when the stride or the
    channel count changes.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels (``expansion`` is 1).
        stride (int): Stride of the first conv module, default to 1.
    """

    expansion = 1  # output channels = out_channels * expansion (BottleNeck uses 4)

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1):
        super().__init__()

        # residual function
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )

        # shortcut: identity unless the output shape differs from the input
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

        # Registered once instead of being re-created on every forward call.
        # ReLU has no parameters, so the state_dict (and saved checkpoints)
        # are unchanged.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        output = self.residual_function(x)
        output += self.shortcut(x)
        return self.relu(output)


class BottleNeck(nn.Module):
    """Bottleneck residual block for ResNet50 and deeper.

    1x1 reduce -> 3x3 (strided) -> 1x1 expand conv-BN stack added to a
    shortcut. The shortcut is the identity, or a 1x1 conv-BN projection when
    the stride or the channel count changes.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of base channels; the block outputs
            ``out_channels * BottleNeck.expansion`` channels.
        stride (int): Stride of the 3x3 conv (and of the projection
            shortcut), default to 1.
    """

    expansion = 4  # output channels = out_channels * expansion (BasicBlock uses 1)

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1):
        super().__init__()

        # residual function
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        # shortcut: identity unless the output shape differs from the input
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

        # Registered once instead of being re-created on every forward call;
        # ReLU has no parameters, so the state_dict is unchanged.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        output = self.residual_function(x)
        output += self.shortcut(x)
        return self.relu(output)


class ResNet(nn.Module):
    """General ResNet.

    The stem is a single 3x3 stride-1 conv-BN-ReLU with no max-pool
    (CIFAR-style). It is followed by four stages of 64/128/256/512 base
    channels; stages 2-4 halve the spatial size. Global average pooling and a
    linear classifier come last.

    Args:
        block (type): Residual block class (e.g. ``BasicBlock`` or
            ``BottleNeck``). It must take ``(in_channels, out_channels, stride)``
            and expose an ``expansion`` class attribute.
        num_block (list[int]): Depth of each of the four stages.
        num_classes (int): Determine the output dimension, default to 10.
    """

    def __init__(self,
                 block: type,
                 num_block: list,
                 num_classes: int = 10):
        super().__init__()

        # running channel count, updated by _make_layer as stages are built
        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.layer1 = self._make_layer(block, 64, num_block[0], 1)
        self.layer2 = self._make_layer(block, 128, num_block[1], 2)
        self.layer3 = self._make_layer(block, 256, num_block[2], 2)
        self.layer4 = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage; only its first block uses ``stride`` (downsampling)."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.layer1(output)
        output = self.layer2(output)
        output = self.layer3(output)
        output = self.layer4(output)
        output = self.avg_pool(output)
        # (N, C, 1, 1) -> (N, C); flatten also copes with non-contiguous input.
        output = output.flatten(1)
        output = self.fc(output)

        return output

# different scale of ResNet. `num_classes` (default 10) sets the output size
# of the final fully-connected layer.
def resnet18(num_classes=10):
    """ResNet-18: BasicBlock stages of depth [2, 2, 2, 2]."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

def resnet34(num_classes=10):
    """ResNet-34: BasicBlock stages of depth [3, 4, 6, 3]."""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)

def resnet50(num_classes=10):
    """ResNet-50: BottleNeck stages of depth [3, 4, 6, 3]."""
    return ResNet(BottleNeck, [3, 4, 6, 3], num_classes)

def resnet101(num_classes=10):
    """ResNet-101: BottleNeck stages of depth [3, 4, 23, 3]."""
    return ResNet(BottleNeck, [3, 4, 23, 3], num_classes)

def resnet152(num_classes=10):
    """ResNet-152: BottleNeck stages of depth [3, 8, 36, 3]."""
    return ResNet(BottleNeck, [3, 8, 36, 3], num_classes)

### ViT
- Use transformer framework on vision tasks
- Trains more slowly than ResNet
- Hyper-parameters are hard to tune
- Because extra datasets and pre-trained models are not allowed, it is hard to reach high accuracy.

In [13]:
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

""""
Reference: https://github.com/lucidrains/vit-pytorch
"""

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class FeedForward(nn.Module):
    """Pre-norm transformer MLP: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the GELU and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Pre-norm multi-head self-attention.

    Args:
        dim: Token feature size (input and output).
        heads: Number of attention heads.
        dim_head: Feature size per head; q, k and v each have
            ``heads * dim_head`` features.
        dropout: Dropout on the attention weights and on the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head whose width already equals `dim` needs no output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # one projection producing q, k and v side by side
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        normed = self.norm(x)

        q, k, v = (
            rearrange(chunk, 'b n (h d) -> b h n d', h = self.heads)
            for chunk in self.to_qkv(normed).chunk(3, dim = -1)
        )

        # scaled dot-product scores, softmax-normalised over the key axis
        scores = q @ k.transpose(-1, -2) * self.scale
        weights = self.dropout(self.attend(scores))

        context = rearrange(weights @ v, 'b h n d -> b n (h d)')
        return self.to_out(context)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm (Attention, FeedForward) residual blocks
    followed by a final LayerNorm.

    Args:
        dim: Token feature size.
        depth: Number of attention + MLP blocks.
        heads: Attention heads per block.
        dim_head: Feature size per attention head.
        mlp_dim: Hidden width of each FeedForward.
        dropout: Dropout used inside attention and the MLP.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout),
            ])
            for _ in range(depth)
        )

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = x + attention(x)
            x = x + feed_forward(x)
        return self.norm(x)

class ViT(nn.Module):
    """Vision Transformer classifier (adapted from lucidrains/vit-pytorch).

    The image is cut into non-overlapping patches, each flattened and
    linearly embedded. A learnable cls token is prepended and a learnable
    position embedding is added. The sequence goes through a Transformer and
    is pooled (cls token or mean) into a linear head.

    Args:
        image_size: int or (height, width) of the input image.
        patch_size: int or (height, width) of each patch; must divide image_size.
        num_classes: Output dimension of the classification head.
        dim: Token embedding size.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of each transformer MLP.
        pool: 'cls' (use the cls token) or 'mean' (average all tokens).
        channels: Number of input image channels.
        dim_head: Feature size per attention head.
        dropout: Dropout inside the transformer.
        emb_dropout: Dropout applied to the token embeddings.
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        grid_h = image_height // patch_height
        grid_w = image_width // patch_width
        num_patches = grid_h * grid_w
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # +1 position for the prepended cls token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        tokens = self.to_patch_embedding(img)
        batch, n_patches, _ = tokens.shape

        cls = repeat(self.cls_token, '1 1 d -> b 1 d', b = batch)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens += self.pos_embedding[:, :n_patches + 1]
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        return self.mlp_head(self.to_latent(pooled))

## 3. Train & Test pipeline

Unified pipelines for training a model and for predicting on the test dataset.

### train pipeline

In [14]:
import os
import time
import torch
def train(model,
          train_dataset,
          val_dataset,
          batch_size,
          epoch,
          loss_function,
          optimizer,
          output_root = './outputs',
          save_epoch = 1,
          resume = None,
          start_epoch = 1,
          device = 'cuda:0',
          scheduler = None,
    ):
    """Train ``model``, validating after every epoch and saving checkpoints.

    Outputs go to ``<output_root>/<timestamp>/``:

    - ``log.log``: progress lines.
    - ``loss.log``: one training loss per iteration.
    - ``best_epoch_<e>_<acc>.pth``: written whenever validation accuracy improves.
    - ``epoch_<e>_<acc>.pth``: written every ``save_epoch`` epochs otherwise.

    Args:
        model (nn.Module): Model to train; moved to ``device``.
        train_dataset (Dataset): Yields ``(input, label)`` pairs; shuffled each epoch.
        val_dataset (Dataset): Yields ``(input, label)`` pairs for validation.
        batch_size (int): Batch size of both data loaders.
        epoch (int): Last epoch to run (inclusive).
        loss_function (callable): ``loss_function(outputs, labels)`` -> scalar loss.
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        output_root (str): Parent directory of the timestamped output folder.
        save_epoch (int): Save a periodic checkpoint when ``e % save_epoch == 0``.
        resume (str, optional): ``state_dict`` checkpoint loaded before training.
        start_epoch (int): First epoch number (use together with ``resume``).
        device (str): Device to train on.
        scheduler (optional): LR scheduler stepped once at the end of every
            epoch. ``None`` (default) leaves the learning rate untouched.
    """
    model = model.to(device)

    # resume
    if resume:
        model.load_state_dict(torch.load(resume, map_location=device))

    # make save path
    current_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
    output_path = os.path.join(output_root, current_time)
    os.makedirs(output_path, exist_ok=True)

    # build dataloaders (validation order does not affect the metrics)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    # used for logging
    iter_num = len(train_dataloader)
    best_acc = 0.0

    with open(os.path.join(output_path, 'log.log'), 'w', encoding='utf-8') as f_log, \
            open(os.path.join(output_path, 'loss.log'), 'w', encoding='utf-8') as f_loss_acc:
        for e in range(start_epoch, epoch + 1):
            # ---- training ----
            model.train()
            loss_sum, correct, seen = 0.0, 0, 0
            for idx, (data, label) in enumerate(train_dataloader):
                data, label = data.to(device), label.to(device)

                outputs = model(data)
                loss = loss_function(outputs, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Accumulate Python numbers (via .item()) rather than tensors so
                # the running totals do not keep references to autograd graphs.
                batch_loss = loss.item()
                batch_n = label.size(0)
                loss_sum += batch_loss * batch_n
                correct += (outputs.argmax(dim=1) == label).sum().item()
                seen += batch_n

                # logging
                if idx % 100 == 0:
                    msg = f'Epoch:{e}/{epoch}, iter:{idx}/{iter_num}, loss:{batch_loss:.4f}'
                    print(msg)
                    f_log.write(msg + '\n')

                f_loss_acc.write(f'{batch_loss:.4f}\n')

            train_loss = loss_sum / max(seen, 1)
            train_accuracy = correct / max(seen, 1)

            # ---- validation ----
            # eval() makes BatchNorm use its running statistics and disables
            # Dropout; in train mode validation would both mis-measure the
            # model and update the BatchNorm statistics.
            model.eval()
            loss_sum, correct, seen = 0.0, 0, 0
            with torch.no_grad():
                for data, label in val_dataloader:
                    data, label = data.to(device), label.to(device)
                    outputs = model(data)
                    batch_n = label.size(0)
                    loss_sum += loss_function(outputs, label).item() * batch_n
                    correct += (outputs.argmax(dim=1) == label).sum().item()
                    seen += batch_n
            val_loss = loss_sum / max(seen, 1)
            val_accuracy = correct / max(seen, 1)

            if scheduler is not None:
                scheduler.step()

            summary = (f'Epoch:{e}/{epoch}, train_loss:{train_loss:.4f}, train_accuracy:{train_accuracy:.4f}, '
                       f'val_loss:{val_loss:.4f}, val_accuracy:{val_accuracy:.4f}')
            print(summary)
            f_log.write(summary + '\n')

            # model saving: a new best gets its own checkpoint ...
            if best_acc < val_accuracy:
                model_name = f'best_epoch_{e}_{val_accuracy:.4f}.pth'
                save_path = os.path.join(output_path, model_name)
                print(f'saving best model to {save_path}')
                torch.save(model.state_dict(), save_path)
                best_acc = val_accuracy
                continue

            # ... otherwise save periodically, keyed on the current epoch `e`.
            if e % save_epoch == 0:
                model_name = f'epoch_{e}_{val_accuracy:.4f}.pth'
                save_path = os.path.join(output_path, model_name)
                print(f'saving model to {save_path}')
                torch.save(model.state_dict(), save_path)


### test pipeline

In [15]:
def test(model,
         checkpoint_path,
         test_dataset,
         batch_size,
         result_path = 'submission.csv',
         device = 'cuda:0',
    ):

    model.load_state_dict(torch.load(checkpoint_path))
    model.to(device)
    test_dataloader = DataLoader(dataset = test_dataset, batch_size=batch_size)
    
    current_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
    output_path = os.path.join('./submissions', current_time)
    os.makedirs(output_path, exist_ok=True)
    f = open(os.path.join(output_path, result_path), 'w', encoding='utf-8')
    f.write('id,label\n')

    with torch.no_grad():
        for data, ids in test_dataloader:
            data = data.to(device)

            outputs = model(data)
            labels = outputs.argmax(dim=1)
            for i in range(len(ids)):
                f.write(f'{ids[i]},{labels[i]}\n')
    f.close()

## 4. Train model

### Train ResNet

In [None]:
from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR
import torch.optim as optim

# build models (alternatives: resnet18 / resnet50 / resnet101 / resnet152)
model = resnet34()

# loss function
loss_function = nn.CrossEntropyLoss()

# select optimizer
lr = 1e-4
momentum = 0.9
weight_decay = 0.0001
optimizer = optim.Adam(model.parameters(), lr=lr)
# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

# scheduler
# NOTE: `scheduler` is not passed to `train` below, so it has no effect on this
# run. If it were stepped once per epoch, StepLR(step_size=1, gamma=0.1) would
# shrink the learning rate 10x every epoch.
gamma = 0.1
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
# scheduler = MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
# scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5)

train(model=model,
      train_dataset=train_dataset,
      val_dataset=val_dataset,
      epoch=100,
      batch_size=32,
      loss_function=loss_function,
      optimizer=optimizer,
      # resume='',        # checkpoint path to continue from
      # start_epoch=21,   # epoch number to resume at
)

### Train ViT

In [None]:
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
import torch.optim as optim

# ViT sized for 32x32 inputs: 4x4 patches give an 8x8 grid of 64 tokens
# (plus one cls token).
model = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=64,
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
    emb_dropout=0.1,
)

loss_function = nn.CrossEntropyLoss()

# Adam with a small weight decay.
lr, weight_decay = 1e-3, 5e-5
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# NOTE: `scheduler` is not passed to `train` below, so it does not change the
# learning rate during this run. Alternative:
#   scheduler = CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-5)
gamma = 0.1
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

train(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    epoch=200,
    batch_size=32,
    loss_function=loss_function,
    optimizer=optimizer,
    # resume='',        # checkpoint path to continue from
    # start_epoch=21,   # epoch number to resume at
)

## 5. Evaluation

Besides accuracy, other metrics can be used to compare models; here we plot a confusion matrix on the validation set.
 

In [16]:
from sklearn.metrics import confusion_matrix
import plotly.figure_factory as ff

def getConfusionMatrix(label_list, prediction_list, save_root=None, class_names=None):
    """Plot a confusion matrix as an annotated plotly heatmap.

    Args:
        label_list (list[int]): Ground-truth class indices.
        prediction_list (list[int]): Predicted class indices.
        save_root (str, optional): If given, the figure is also written to
            ``<save_root>/confusion_matrix.html``.
        class_names (list[str], optional): Display name of each class index.
            Defaults to the 10 CIFAR-10 class names.

    Returns:
        numpy.ndarray: Confusion matrix; rows are actual labels and columns
        are predicted labels.
    """
    import os

    if class_names is None:
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    cm = confusion_matrix(label_list, prediction_list, labels=list(range(len(class_names))), normalize=None)
    x = y = list(class_names)

    # Plot the matrix as a heatmap with the counts annotated in each cell.
    fig = ff.create_annotated_heatmap(cm, x, y)
    # Set titles and ordering
    fig.update_layout(title_text="<b>Confusion matrix</b>",
                      yaxis=dict(categoryorder="category descending"))
    fig.add_annotation(dict(font=dict(color="black", size=14),
                            x=0.5,
                            y=-0.15,
                            showarrow=False,
                            text="Predicted label",
                            xref="paper",
                            yref="paper"))
    fig.add_annotation(dict(font=dict(color="black", size=14),
                            x=-0.15,
                            y=0.5,
                            showarrow=False,
                            text="Actual label",
                            textangle=-90,
                            xref="paper",
                            yref="paper"))
    # We need margins so the titles fit
    fig.update_layout(margin=dict(t=80, r=20, l=100, b=50))
    fig['data'][0]['showscale'] = True

    if save_root is not None:
        os.makedirs(save_root, exist_ok=True)
        fig.write_html(os.path.join(save_root, 'confusion_matrix.html'))
    fig.show()

    return cm


# import matplotlib.pyplot as plt
# from sklearn.metrics import roc_curve, auc, roc_auc_score, recall_score, f1_score
# from sklearn.utils.multiclass import type_of_target
# import numpy as np
# from sklearn.preprocessing import label_binarize


# # def getRocAuc(y_true, y_scores):
# #     y_one_hot = label_binarize(y_true, classes=np.arange(10))
# #     fpr, tpr, threshold = roc_curve(y_one_hot.ravel(), y_scores.ravel())
# #     roc_auc = auc(fpr, tpr)
# #     plt.figure()
# #     lw = 2
# #     plt.figure(figsize=(10, 10))
# #     plt.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)  # false positive rate on the x-axis, true positive rate on the y-axis
# #     plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
# #     plt.xlim([0.0, 1.0])
# #     plt.ylim([0.0, 1.05])
# #     plt.legend(loc="lower right")
# #     plt.show()

In [22]:
def evaluate(model,
         checkpoint_path,
         val_dataset,
         batch_size,
         device = 'cuda:0',
    ):
    """Evaluate a checkpoint on ``val_dataset`` and plot its confusion matrix.

    Args:
        model (nn.Module): Model instance; its weights are loaded from ``checkpoint_path``.
        checkpoint_path (str): Path of a saved ``state_dict``.
        val_dataset (Dataset): Yields ``(input, label)`` pairs.
        batch_size (int): Inference batch size.
        device (str): Device used for inference.

    Returns:
        tuple: ``(val_accuracy, cm)``, the overall accuracy as a float and the
        matrix returned by ``getConfusionMatrix``.
    """
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    # eval mode: BatchNorm running statistics, Dropout disabled
    model.eval()
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size)

    label_list = []
    prediction_list = []
    with torch.no_grad():
        for data, label in val_dataloader:
            predictions = model(data.to(device)).argmax(dim=1)
            label_list += label.tolist()
            prediction_list += predictions.tolist()

    # Exact sample-level accuracy (not a mean of per-batch means).
    correct = sum(p == t for p, t in zip(prediction_list, label_list))
    val_accuracy = correct / max(len(label_list), 1)
    print(f'val_accuracy:{val_accuracy:.4f}')

    cm = getConfusionMatrix(label_list, prediction_list)
    return val_accuracy, cm
# Evaluate the saved best ResNet-34 checkpoint on the validation split.
evaluate(model=resnet34(),
     checkpoint_path='./best_epoch_98_0.9131.pth',
     val_dataset=val_dataset,
     batch_size=32
)

## 6. Ensemble Learning

In [None]:
def EnsembleLearning(model1,
         checkpoint_path1,
         model2,
         checkpoint_path2,
         test_dataset,
         batch_size,
         result_path = 'submission.csv',
         device = 'cuda:0',
         weights = (4, 1),
         output_root = './submissions',
    ):
    """Predict test labels from a weighted sum of two models' logits.

    The combined score is ``weights[0] * model1(x) + weights[1] * model2(x)``.
    The arg-max of that score is written as an ``id,label`` CSV to
    ``<output_root>/<timestamp>/<result_path>``.

    Args:
        model1, model2 (nn.Module): Model instances; their weights are loaded
            from ``checkpoint_path1`` / ``checkpoint_path2``.
        checkpoint_path1, checkpoint_path2 (str): Paths of saved ``state_dict``s.
        test_dataset (Dataset): Yields ``(image_tensor, image_id)`` pairs.
        batch_size (int): Inference batch size.
        result_path (str): File name of the submission CSV.
        device (str): Device used for inference.
        weights (tuple): Logit weights of ``(model1, model2)``; default ``(4, 1)``.
        output_root (str): Parent directory of the timestamped output folder.

    Returns:
        str: Path of the written CSV file.
    """
    w1, w2 = weights
    model1.load_state_dict(torch.load(checkpoint_path1, map_location=device))
    model1.to(device)
    model2.load_state_dict(torch.load(checkpoint_path2, map_location=device))
    model2.to(device)
    # Both models must be in eval mode (BatchNorm running stats, no Dropout).
    model1.eval()
    model2.eval()
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size)

    current_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
    output_path = os.path.join(output_root, current_time)
    os.makedirs(output_path, exist_ok=True)
    csv_path = os.path.join(output_path, result_path)

    with open(csv_path, 'w', encoding='utf-8') as f, torch.no_grad():
        f.write('id,label\n')
        for data, ids in test_dataloader:
            data = data.to(device)
            outputs = w1 * model1(data) + w2 * model2(data)
            labels = outputs.argmax(dim=1).tolist()
            for image_id, label in zip(ids, labels):
                f.write(f'{image_id},{label}\n')
    return csv_path

# Ensemble the ResNet-34 checkpoint with a ViT checkpoint. The ViT
# hyper-parameters mirror the ViT training cell so the parameter shapes in
# the saved state_dict match.
model1 = resnet34()
model2 = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=64,
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
    emb_dropout=0.1,
)

EnsembleLearning(
    model1=model1,
    checkpoint_path1='./best_epoch_98_0.9131.pth',
    model2=model2,
    checkpoint_path2='./best_epoch_382_0.8149.pth',
    test_dataset=test_dataset,
    batch_size=32,
)