In [7]:
from transformers import ViTForImageClassification
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset

import pandas as pd
import numpy as np

## Data

In [19]:
# Kaggle "Digit Recognizer" test split: one unlabeled row of pixel values per image.
TEST_CSV_PATH = '/kaggle/input/digit-recognizer/test.csv'
test_df = pd.read_csv(TEST_CSV_PATH)

In [20]:
# Each row of test.csv is one flattened 28x28 digit; divide by 255 to bring the
# raw pixel intensities (assumed 0-255) into [0, 1].
X_test = torch.tensor(test_df.values, dtype=torch.float32) / 255.0
assert X_test.shape[1] == 28 * 28, f"expected 784 pixel columns, got {X_test.shape[1]}"

# Keep a single channel. The `Grayscale(num_output_channels=3)` step of the
# per-image transform already expands every image to 3 identical channels, so
# repeating here would only triple the memory held by this tensor without
# changing the pixels the model sees.
X_test = X_test.view(-1, 1, 28, 28)

In [40]:
# Per-image preprocessing: tensor -> PIL image, force 3 channels, upscale the
# 28x28 digit to the 224x224 input size of vit-base-patch16-224, then convert
# back to a float tensor.
_preprocess_steps = [
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

In [41]:
class MNISTDataset(Dataset):
    """Indexable wrapper around a stack of unlabeled images.

    Args:
        images: any object supporting ``len()`` and integer indexing; in this
            notebook, the ``X_test`` tensor viewed as (N, C, 28, 28).
        transform: optional callable applied to each image when it is fetched.
            When it is falsy, the raw image is returned unchanged.
    """

    def __init__(self, images, transform=None):
        self.images, self.transform = images, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        # No label is returned: the dataset yields images only.
        return self.transform(sample) if self.transform else sample

In [42]:
test_dataset = MNISTDataset(images=X_test, transform=transform)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,  # keep test.csv row order so predictions line up with ImageIds
)

## Model

In [None]:
# Build the ViT-Base/16 (224x224) architecture with a 10-way classifier head.
# The checkpoint's head does not match num_labels=10, so
# `ignore_mismatched_sizes` lets loading proceed with a freshly initialised head;
# all weights are replaced by the fine-tuned state dict in the next cell.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=10,
    ignore_mismatched_sizes=True,
)

In [None]:
# Load the fine-tuned weights onto the CPU first. The model is moved to the
# inference device later, and mapping to CPU here keeps the cell working on a
# CPU-only kernel even if the checkpoint was saved from GPU tensors.
# `weights_only=True` restricts unpickling to tensors and plain containers.
# NOTE(review): this assumes the .pth file is a plain state dict, as the
# load_state_dict call implies.
state_dict = torch.load(
    "/kaggle/input/vit-for-mnist/transformers/default/1/fine_tuned_vit_mnist.pth",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)

In [43]:
# Switch to inference mode (e.g. dropout disabled); equivalent to model.eval().
model.train(False)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [46]:
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)

# One NumPy array of predicted class ids per batch, in loader (= test.csv) order.
all_predictions = []

with torch.no_grad():
    for batch in test_loader:
        logits = model(batch.to(device)).logits
        # Index of the highest logit per row; same call as torch.max(logits, 1).
        batch_preds = logits.max(dim=1).indices
        all_predictions.append(batch_preds.cpu().numpy())

In [47]:
# Flatten the per-batch arrays into one vector of predicted labels in test-set
# order. The result goes to a new name instead of overwriting `all_predictions`,
# so this cell can be re-run: concatenating an already-flat array would raise
# "zero-dimensional arrays cannot be concatenated".
test_predictions = np.concatenate(all_predictions)
predictions_list = test_predictions.tolist()

In [48]:
# One row per test image: the 1-based ImageId is the (named) index and the
# predicted Label is the only column. The frame is built in a single step
# instead of renaming the index, shifting it and then renaming columns. That
# sequence also failed with a length mismatch when there were no predictions.
submission = pd.DataFrame(
    {'Label': predictions_list},
    index=pd.RangeIndex(start=1, stop=len(predictions_list) + 1, name='ImageId'),
)

In [49]:
# `index` is a boolean flag. Write the index explicitly and label its header
# 'ImageId' instead of relying on a truthy string argument.
submission.to_csv('submission.csv', index=True, index_label='ImageId')