In [24]:
import timm
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
import torchvision
from torch.optim import Adam
import os
from PIL import Image

In [25]:
# Image folders for each split and class. get_images() expects each of them to
# contain sub-directories, each holding a 'Segmentadas' folder of images.
DATABASE_ROOT = './database'
train_healthy_dir = f'{DATABASE_ROOT}/train_data/healthy'
train_sick_dir = f'{DATABASE_ROOT}/train_data/sick'
test_healthy_dir = f'{DATABASE_ROOT}/test_data/healthy'
test_sick_dir = f'{DATABASE_ROOT}/test_data/sick'

In [26]:
def get_images(path, label, subdir='Segmentadas'):
    """Collect ``(image_path, label)`` pairs from ``<path>/<entry>/<subdir>/``.

    Every directory directly under ``path`` is expected to contain a ``subdir``
    folder. Every entry inside that folder is returned, paired with ``label``.
    Non-directory entries directly under ``path`` (e.g. ``.DS_Store``,
    ``Thumbs.db``) are skipped.

    Args:
        path: class folder to scan (e.g. the "healthy" or "sick" split folder).
        label: label attached to every collected file.
        subdir: name of the image folder inside each entry directory.

    Returns:
        A list of ``(file_path, label)`` tuples. Entries are sorted by name, so
        the order is reproducible across runs and operating systems.

    Raises:
        FileNotFoundError: if an entry directory has no ``subdir`` folder.
    """
    samples = []
    # os.listdir order is filesystem-dependent; sort for reproducible sample order.
    for entry_name in sorted(os.listdir(path)):
        entry_dir = os.path.join(path, entry_name)
        if not os.path.isdir(entry_dir):
            continue  # stray file next to the entry folders
        image_dir = os.path.join(entry_dir, subdir)
        for file_name in sorted(os.listdir(image_dir)):
            samples.append((os.path.join(image_dir, file_name), label))
    return samples

In [27]:
# (image_path, label) pairs per split; label 0 = healthy, 1 = sick.
train_data = [*get_images(train_healthy_dir, 0), *get_images(train_sick_dir, 1)]
test_data = [*get_images(test_healthy_dir, 0), *get_images(test_sick_dir, 1)]

In [28]:
from PIL import Image
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    """Map-style dataset that loads images from disk and yields ``(image, label)``.

    Args:
        image_paths: sequence of image file paths, one per sample.
        labels: sequence of labels aligned index-by-index with ``image_paths``.
        transform: optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail fast: a mismatch would otherwise only surface as an IndexError mid-epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                f'image_paths and labels must have the same length '
                f'({len(image_paths)} != {len(labels)})'
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager releases the file handle deterministically. convert()
        # returns a new, fully loaded image, so it stays valid after the file closes.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Usage:
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

# Per-channel normalization stats of the pretrained backbone. timm's pretrained
# vit_base_patch16_224 config uses mean = std = 0.5 (not the ImageNet
# 0.485/0.229 stats), so inputs are normalized the same way here.
# NOTE(review): confirm for the installed timm version with
# `timm.data.resolve_data_config({}, model=model)`.
VIT_MEAN = (0.5, 0.5, 0.5)
VIT_STD = (0.5, 0.5, 0.5)

transform = Compose([
    Resize((224, 224)),  # fixed 224x224 input expected by the patch16_224 ViT
    ToTensor(),  # PIL image -> float tensor in [0, 1], channels first
    Normalize(mean=VIT_MEAN, std=VIT_STD),
])

def get_torch_dataset(dataset):
    """Wrap a sequence of ``(image_path, label)`` pairs in a ``CustomImageDataset``.

    The module-level ``transform`` pipeline, looked up at call time, is applied
    to every loaded image.
    """
    paths = [sample[0] for sample in dataset]
    targets = [sample[1] for sample in dataset]
    return CustomImageDataset(paths, targets, transform=transform)


In [29]:
# Fine-tune a pretrained ViT-B/16 on the healthy/sick images, then report test accuracy.
SEED = 0
NUM_EPOCHS = 10
BATCH_SIZE = 32
# Adam's default lr (1e-3) is usually too aggressive for fine-tuning a pretrained ViT.
LEARNING_RATE = 1e-4

torch.manual_seed(SEED)  # reproducible head init and shuffling
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 2  # healthy (0) vs sick (1)
# num_classes makes timm load the pretrained backbone with a freshly initialised head.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes).to(device)

# Build the datasets from the (path, label) lists collected earlier.
train_dataset = get_torch_dataset(train_data)
test_dataset = get_torch_dataset(test_data)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation order does not affect accuracy, so the test set is not shuffled.
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

criterion = nn.CrossEntropyLoss()
# Created after model.to(device) so the optimizer references the on-device parameters.
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``; return the mean per-sample loss."""
    model.train()
    running_loss = 0.0
    seen = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch is not over-counted.
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    return running_loss / max(seen, 1)


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return ``(correct, total)`` prediction counts of ``model`` over ``loader``."""
    model.eval()
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    return correct, total


for epoch in range(NUM_EPOCHS):
    epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f'Epoch: {epoch + 1}, loss: {epoch_loss:.3f}')

correct, total = _evaluate(model, test_loader, device)
if total:
    print(f'Accuracy: {100 * correct / total}%')
else:
    print('Accuracy: n/a (test set is empty)')

model.safetensors: 100%|██████████| 346M/346M [00:29<00:00, 11.6MB/s] 
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


KeyboardInterrupt: 