In [242]:
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch import nn

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

from transformers import ViTFeatureExtractor
import torchvision.transforms.functional as F


## Hyperparameters

In [243]:
batch_size=1
learning_rate = 0.1

if torch.cuda.is_available():
	dev = "cuda"
else:
	dev = "cpu"

device = torch.device(dev)

## Download Datasets

In [244]:
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('../data', train=False, download=True, transform=transform)

## Load Data

In [245]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, shuffle = True)

## Model

In [246]:
class LinearClassifier(nn.Module):

	def __init__(
		self,
		classes: int = 10,
		width: int = 28,
		height: int = 28,
		channels: int = 1,
	):
		super().__init__()

		self.classes = classes
		self.width = width
		self.height = height
		self.channels = channels

        # Layer 1
		self.fc_1 = nn.Linear(width * height * channels, self.classes)

	def forward(self, x):
		x = torch.flatten(x, 1)
		x = self.fc_1(x)
		return x

In [247]:
class ViTModel(nn.Module):
	
	def __init__(self):
		super().__init__()

		self.classifier = LinearClassifier()
		self.model_path = 'google/vit-base-patch16-224-in21k'
		self.feature_extractor = ViTFeatureExtractor.from_pretrained(self.model_path)
	
	def forward(self, x):
		print(x)
		x = self.feature_extractor(x, return_tensors='pt')
		print("LOOOOOOOOOOOOOOL")
		# print(test.shape)
		print(x)
		x = self.classifier(x)
		return x


In [253]:
class ViTModel2(nn.Module):
	
	def __init__(self):
		super().__init__()

		self.classifier = LinearClassifier()
		self.vit = ViT(
			image_size = 28,
			patch_size = 7,
			num_classes = 10,
			dim = 28,
			channels = 1
			depth = 6,
			heads = 16,
			mlp_dim = 2048,
			dropout = 0.1,
			emb_dropout = 0.1,
		)
		self.feature_extractor = Extractor(self.vit, return_embeddings_only=True)
	
	def forward(self, x):
		x = self.feature_extractor(x)
		x = self.classifier(x)
		return x

In [249]:
model = ViTModel2()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [250]:
def train(train_loader, model, loss_fn, optmizer):
	size = len(train_loader.dataset)
	batches_l = len(train_loader)
	loss = 0
	correct = 0

	for batch_idx, (data, target) in enumerate(train_loader):
		pred = model(data)
		loss = loss_fn(pred, target)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		if batch_idx % 100 == 0:
			print(f'loss: {loss}')

In [251]:
def test(test_loader, model, loss_fn):
	size = len(test_loader.dataset)
	loss = 0
	correct_n = 0
	correct = 0

	for batch_idx, (data, target) in enumerate(test_loader):
		pred = model(data)
		argmax = pred.argmax(dim=1, keepdim=True)
		correct_n = argmax.eq(target.view_as(argmax)).sum().item()
		correct += correct_n
	return correct / size


In [252]:
train(train_loader, model, loss_fn, optimizer)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x476 and 784x10)

In [None]:
print(f'accuracy: {test(test_loader, model, loss_fn)}')

accuracy: 0.9686


: 

: 