# FINE-TUNING DiT

# ![https://images.pexels.com/photos/357514/pexels-photo-357514.jpeg?cs=srgb&dl=pexels-pixabay-357514.jpg&fm=jpg](https://images.pexels.com/photos/357514/pexels-photo-357514.jpeg?cs=srgb&dl=pexels-pixabay-357514.jpg&fm=jpg)

# IMPORT LIBRARIES

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from transformers import AutoModelForImageClassification

# CONFIGS + PREPROCESSES

In [None]:
model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
num_classes = 4
model.avg_pooling = torch.nn.AdaptiveAvgPool2d(1)
model.classifier = torch.nn.Linear(model.config.hidden_size, num_classes)
criterion = torch.nn.CrossEntropyLoss()

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to fit the model input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize images
])

data_folder = "./Datasets/DOCS/train/"
train_dataset = ImageFolder(data_folder, transform=transform, is_valid_file=lambda filename: not filename.endswith('.ipynb_checkpoints'))

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

optimizer = torch.optim.AdamW([
    {'params': model.base_model.parameters(), 'lr': 1e-6},  # Pre-trained layers
    {'params': model.classifier.parameters()}  # New classifier layer
], lr=1e-5, weight_decay=0.01)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# TRAIN IT UP

In [None]:
num_epochs = 2
for epoch in range(num_epochs):
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs, return_dict=True)
        logits = outputs.logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item()}")


# TEST IT UP

In [16]:
data_folder = "./Datasets/DOCS/test/"
test_dataset = ImageFolder(data_folder, transform=transform, is_valid_file=lambda filename: not filename.endswith('.ipynb_checkpoints'))

if len(test_dataset) == 0:
    raise ValueError("No images found in the dataset. Please check the 'test' subfolders.")

test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)

In [29]:
# Evaluation loop
model.eval()  # Set the model to evaluation mode
total_correct = 0
total_samples = 0

with torch.no_grad():  # Disable gradient calculation for evaluation
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs, return_dict=True)
        logits = outputs.logits
        predicted_labels = logits.argmax(dim=1)
        total_correct += (predicted_labels == labels).sum().item()
        total_samples += labels.size(0)

accuracy = total_correct / total_samples
print(f"Test Accuracy: {accuracy:.2f}")

Test Accuracy: 0.85


# MANGALAM