In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import os
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
from torchvision.transforms import Resize
import torch.optim as optim
from torch.utils.data import DataLoader , TensorDataset
from torch import Tensor

from torchvision.models import resnet18
from einops.layers.torch import Rearrange
from einops import rearrange


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from numpy import load

# Pre-split arrays stored next to the notebook, one .npy file per split.
x_train, y_train, x_test, y_test = (
    load(path) for path in ("x_train.npy", "y_train.npy", "x_test.npy", "y_test.npy")
)

In [None]:
# Sanity check: displays whether PyTorch can see a CUDA device (the same
# condition used to choose `device` in the imports cell).
torch.cuda.is_available()

In [None]:
# Wrap the NumPy arrays as tensors: float32 images and int64 labels (the dtype
# nn.CrossEntropyLoss expects for class-index targets).
dataset_train = TensorDataset(torch.as_tensor(x_train, dtype=torch.float32), torch.as_tensor(y_train).long())
dataset_test = TensorDataset(torch.as_tensor(x_test, dtype=torch.float32), torch.as_tensor(y_test).long())

# NOTE(review): the ViT's positional embedding is sized for image_size x image_size
# inputs and its forward() permutes channels-last input. No resizing transform is
# actually applied, so x_train / x_test presumably already are (N, 224, 224, 3).
# TODO confirm.
image_size = 224

# torchvision augmentation / normalisation pipelines.
# NOTE: these are currently NOT applied to the data. TensorDataset has no transform
# hook: it never reads a `transform` attribute, so assigning one has no effect.
# transforms.ToTensor also expects a PIL image or NumPy array, not a torch.Tensor.
# To use them, wrap the arrays in a custom Dataset that applies a transform per item.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create DataLoader objects.
# DataLoader has no `.to()` method; each batch is moved to `device` inside the
# training / evaluation loops instead.
batch_size = 32
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)




In [None]:
# Define Vision Transformer model
class ViT(nn.Module):
    """Minimal Vision Transformer classifier.

    The model works in four steps:
    1. Splits the image into non-overlapping ``patch_size`` x ``patch_size`` patches
       with a strided convolution.
    2. Prepends a learnable CLS token and adds learned positional embeddings.
    3. Runs a stack of ``nn.TransformerEncoderLayer``s.
    4. Classifies from the final CLS embedding.

    Args:
        image_size: side length of the square input images, in pixels.
        patch_size: side length of each square patch.
        num_classes: number of output logits.
        dim: token embedding width; must be divisible by ``heads``.
        depth: number of transformer encoder layers (default 6).
        heads: number of attention heads per layer (default 8).
        in_channels: number of image channels (default 3).
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth=6, heads=8, in_channels=3):
        super().__init__()
        self.patch_embedding = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        # The strided conv yields image_size // patch_size patches per side.
        num_patches = (image_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # One position per patch plus one for the CLS token; this fixes the input resolution.
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # batch_first=True because tokens are laid out as (batch, tokens, dim).
        # With the default batch_first=False, the encoder would treat the batch axis
        # as the sequence and attend across *different images* instead of across patches.
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` must be channels-last, (batch, height, width, channels). Its spatial
        size must produce the same number of patches as ``image_size`` did.
        """
        x = x.permute(0, 3, 1, 2)               # (B, H, W, C) -> (B, C, H, W) for Conv2d
        batch = x.shape[0]
        x = self.patch_embedding(x)             # (B, dim, H // p, W // p)
        x = x.flatten(2).transpose(1, 2)        # (B, num_patches, dim), row-major patch order
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)   # (B, num_patches + 1, dim)
        x = x + self.positional_embedding
        x = self.transformer_encoder(x)
        return self.fc(x[:, 0])                 # classify from the CLS token

# Training parameters
# NOTE: the DataLoaders are created in the data cell with batch_size = 32 and are
# not rebuilt here. This value is kept in sync with that cell so it does not
# suggest a different batch size is in use.
batch_size = 32
image_size = 224   # must match the spatial size of the input images
patch_size = 16    # 224 // 16 = 14 patches per side -> 196 patch tokens
num_classes = 2
dim = 768          # transformer width; must be divisible by the ViT's attention head count (8)
num_epochs = 5



In [None]:
# Initialize model, loss, and optimizer
seed = 0
torch.manual_seed(seed)  # reproducible weight init and DataLoader shuffling

# NOTE(review): 3e-3 is on the high side for Adam on a transformer trained from
# scratch; consider ~1e-4 to 3e-4 if the loss stalls or diverges.
learning_rate = 3e-3


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the sample-weighted mean loss."""
    model.train()
    total_loss, seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # The criterion returns the batch mean (its default reduction). Weight it by
        # batch size so a short final batch does not skew the epoch average.
        total_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    return total_loss / max(seen, 1)


def _evaluate(model, loader, device):
    """Return ``(correct, total)`` top-1 prediction counts over ``loader``."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


model = ViT(image_size, patch_size, num_classes, dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

# Evaluation
correct, total = _evaluate(model, test_loader, device)
if total:
    print(f"Accuracy: {100 * correct / total}%")
else:
    print("Accuracy: n/a (empty test set)")