In [2]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm
from transformers import ViTForImageClassification, ViTFeatureExtractor

In [4]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
train_code_dir = "/home/bartek/Kod/PD/praca_dyplomowa/dane/resnet_dane/ready/train/code"  # Contains 'code' and 'non_code' subfolders
train_no_code_dir = "/home/bartek/Kod/PD/praca_dyplomowa/dane/resnet_dane/ready/train/no_code"  # Contains 'code' and 'non_code' subfolders
val_code_dir = "/home/bartek/Kod/PD/praca_dyplomowa/dane/resnet_dane/ready/val/code"      # Contains 'code' and 'non_code' subfolders
val_no_code_dir = "/home/bartek/Kod/PD/praca_dyplomowa/dane/resnet_dane/ready/val/no_code"      # Contains 'code' and 'non_code' subfolders

In [7]:
model_name = "google/vit-base-patch16-224"

In [8]:
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



In [9]:
# Load images from folders with existing structure
def load_dataset_from_folders(code_dir, no_code_dir):
    images = []
    labels = []
    
    # Load code images (label 1)
    for img_name in os.listdir(code_dir):
        if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            images.append(os.path.join(code_dir, img_name))
            labels.append(1)
    
    # Load non-code images (label 0)
    for img_name in os.listdir(no_code_dir):
        if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            images.append(os.path.join(no_code_dir, img_name))
            labels.append(0)
    
    return images, labels

# Load training and validation datasets
train_images, train_labels = load_dataset_from_folders(train_code_dir, train_no_code_dir)
val_images, val_labels = load_dataset_from_folders(val_code_dir, val_no_code_dir)

print(f"Training samples: {len(train_images)}")
print(f"Validation samples: {len(val_images)}")

Training samples: 9044
Validation samples: 2262


In [10]:
# Function to preprocess a batch of images
def preprocess_batch(image_paths, labels):
    images = []
    for img_path in image_paths:
        img = Image.open(img_path).convert("RGB")
        images.append(img)
    
    # Process images using the ViT feature extractor
    inputs = feature_extractor(images=images, return_tensors="pt")
    inputs['labels'] = torch.tensor(labels)
    return inputs

In [11]:
# Create dataloader function
def create_dataloader(image_paths, labels, batch_size=16, shuffle=True):
    indices = list(range(len(image_paths)))
    if shuffle:
        np.random.shuffle(indices)
    
    # Create mini-batches
    batches = []
    for i in range(0, len(indices), batch_size):
        batch_indices = indices[i:i + batch_size]
        batch_images = [image_paths[idx] for idx in batch_indices]
        batch_labels = [labels[idx] for idx in batch_indices]
        batches.append((batch_images, batch_labels))
    
    return batches

In [12]:
# Load pre-trained ViT model and modify for binary classification
model = ViTForImageClassification.from_pretrained(model_name, num_labels=2)
model.to(device)

# Set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

RuntimeError: Error(s) in loading state_dict for ViTForImageClassification:
	size mismatch for classifier.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([2, 768]).
	size mismatch for classifier.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([2]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

In [None]:
# Training function
def train_model(model, train_dataloader, val_dataloader, epochs=3):
    best_accuracy = 0
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0
        
        for batch_images, batch_labels in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            inputs = preprocess_batch(batch_images, batch_labels)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            optimizer.zero_grad()
            outputs = model(**inputs)
            loss = outputs.loss
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        avg_train_loss = train_loss / len(train_dataloader)
        
        # Validation phase
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch_images, batch_labels in tqdm(val_dataloader, desc="Validating"):
                inputs = preprocess_batch(batch_images, batch_labels)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                
                outputs = model(**inputs)
                _, predicted = torch.max(outputs.logits, 1)
                labels = inputs['labels']
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        accuracy = 100 * correct / total
        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
        
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            # Save the best model
            torch.save(model.state_dict(), "vit_code_classifier.pth")
    
    return model

In [None]:
# Create data loaders
batch_size = 16
train_dataloader = create_dataloader(train_images, train_labels, batch_size)
val_dataloader = create_dataloader(val_images, val_labels, batch_size, shuffle=False)

# Train the model
trained_model = train_model(model, train_dataloader, val_dataloader, epochs=3)

In [None]:
# Function for inference on a single image
def predict_image(model, image_path):
    model.eval()
    img = Image.open(image_path).convert("RGB")
    inputs = feature_extractor(images=[img], return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
        prediction = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][prediction].item()
    
    label = "Code" if prediction == 1 else "Not Code"
    return label, confidence

In [None]:
# Example usage for inference
test_image = "path/to/test/image.png"
label, confidence = predict_image(trained_model, test_image)
print(f"Prediction: {label} (Confidence: {confidence:.2f})")

In [None]:
# Function for batch inference on a folder of images
def batch_predict(model, image_dir):
    model.eval()
    results = []
    
    for img_name in os.listdir(image_dir):
        if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(image_dir, img_name)
            label, confidence = predict_image(model, img_path)
            results.append((img_name, label, confidence))
    
    return results

# Optional: How to load a saved model later
# model = ViTForImageClassification.from_pretrained(model_name, num_labels=2)
# model.load_state_dict(torch.load("vit_code_classifier.pth"))
# model.to(device)