In [12]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from vit_pytorch import ViT
import os


In [4]:
!pip install vit-pytorch


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting vit-pytorch
  Downloading vit_pytorch-1.2.1-py3-none-any.whl (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.3/87.3 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.6.1 (from vit-pytorch)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, vit-pytorch
Successfully installed einops-0.6.1 vit-pytorch-1.2.1


In [6]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [15]:
class MedicalImages(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform

        # Load image file paths and infer labels from folder names
        self.image_labels = []
        for label, folder_name in enumerate(sorted(os.listdir(data_path))):
            folder_path = os.path.join(data_path, folder_name)
            if os.path.isdir(folder_path):
                for filename in os.listdir(folder_path):
                    file_path = os.path.join(folder_path, filename)
                    if os.path.isfile(file_path):
                        self.image_labels.append((file_path, label))

    def __len__(self):
        return len(self.image_labels)

    def __getitem__(self, index):
        # Load image and label
        image_path, label = self.image_labels[index]
        image = Image.open(image_path).convert('RGB')

        # Apply transformations
        if self.transform is not None:
            image = self.transform(image)

        return image, label



In [10]:
from torchvision import transforms

# Define data transformations
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [17]:
# Load data
data_path = '/content/drive/MyDrive/dataset_medial/train'
dataset = MedicalImages(data_path, transform=transform)
loader = DataLoader(dataset, batch_size=16, shuffle=True)


In [18]:
# Define model
model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 5,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train model
num_epochs = 10
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(loader):
        # Forward pass
        output = model(data)
        loss = criterion(output, target)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print loss after every 100 batches
        if batch_idx % 100 == 0:
            print('Epoch: {}, Batch: {}, Loss: {:.4f}'.format(epoch+1, batch_idx, loss.item()))


Epoch: 1, Batch: 0, Loss: 0.3883
Epoch: 2, Batch: 0, Loss: 0.0000
Epoch: 3, Batch: 0, Loss: 0.0000
Epoch: 4, Batch: 0, Loss: 0.0000
Epoch: 5, Batch: 0, Loss: 0.0000
Epoch: 6, Batch: 0, Loss: 0.0000
Epoch: 7, Batch: 0, Loss: 0.0000
Epoch: 8, Batch: 0, Loss: 0.0000
Epoch: 9, Batch: 0, Loss: 0.0000
Epoch: 10, Batch: 0, Loss: 0.0000


In [19]:
# Load test dataset
test_dataset = MedicalImages('/content/drive/MyDrive/dataset_medial/test', transform=transform)

# Set model to evaluation mode
model.eval()

# Disable gradient calculation
with torch.no_grad():
    # Iterate through test dataset
    num_correct = 0
    num_samples = 0
    for data, target in test_dataset:
        # Get predicted labels using model
        output = model(data.unsqueeze(0))
        _, predicted = torch.max(output, dim=1)

        # Update number of correct predictions and number of samples
        num_correct += (predicted == target).sum().item()
        num_samples += 1

    # Calculate accuracy
    accuracy = num_correct / num_samples
    print('Accuracy: {:.2f}%'.format(accuracy * 100))


Accuracy: 38.59%


In [None]:
from google.colab import drive
drive.mount('/content/drive')