# FINE-TUNING DiT

# ![https://images.pexels.com/photos/357514/pexels-photo-357514.jpeg?cs=srgb&dl=pexels-pixabay-357514.jpg&fm=jpg](https://images.pexels.com/photos/357514/pexels-photo-357514.jpeg?cs=srgb&dl=pexels-pixabay-357514.jpg&fm=jpg)

In [19]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import glob, os

# MODEL CONFIGS

In [3]:
# Model setup: load the pretrained DiT checkpoint, attach a pooling layer and
# a linear head sized for the document classes, and define the training loss.
# NOTE(review): `avg_pooling` and `classifier` are only attached as attributes
# here; whether the model's forward pass routes through them is not visible in
# this file -- confirm against the BeitForMaskedImageModeling documentation.
num_classes = 4
model = BeitForMaskedImageModeling.from_pretrained("microsoft/dit-base")
model.avg_pooling = torch.nn.AdaptiveAvgPool2d(output_size=1)
model.classifier = torch.nn.Linear(
    in_features=model.config.hidden_size,
    out_features=num_classes,
)
criterion = torch.nn.CrossEntropyLoss()

# CUSTOM DATASET

In [4]:
class CustomDataset(Dataset):
    """Image-classification dataset backed by a folder-per-class layout.

    ``folder_path`` must contain one sub-directory per class. Every regular
    file inside a class directory whose extension is a supported image type
    becomes one sample labelled with that directory's name.

    Attributes:
        folder_path: Root directory that was scanned.
        transform: Optional callable applied to each loaded RGB PIL image.
        image_paths: Paths of all discovered images, in sorted order.
        labels: Class-directory name of each image (parallel to image_paths).
        class_to_idx: Mapping from class-directory name to integer label.
    """

    # File extensions (lower-case) that are treated as images.
    VALID_EXTENSIONS = (
        '.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp',
    )

    def __init__(self, folder_path, transform=None):
        """Scan ``folder_path`` and index every image found under it.

        Args:
            folder_path: Directory containing one sub-directory per class.
            transform: Optional callable applied to each image on access.
        """
        self.folder_path = folder_path
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}

        self._load_images()

    @staticmethod
    def _build_class_index(class_names):
        """Map class-directory names to integer labels.

        Purely numeric names keep their own integer value, so existing
        datasets with folders such as ``0``..``3`` are labelled exactly as
        before. If any name is not an integer, the sorted names are numbered
        ``0..n-1`` instead of failing later inside ``__getitem__``.
        """
        try:
            return {name: int(name) for name in class_names}
        except ValueError:
            return {name: index for index, name in enumerate(class_names)}

    def _load_images(self):
        """Populate ``image_paths``, ``labels`` and ``class_to_idx``."""
        # Sort directory listings: os.listdir order is arbitrary, and a stable
        # order keeps sample indices reproducible across runs and machines.
        class_names = sorted(
            name
            for name in os.listdir(self.folder_path)
            if os.path.isdir(os.path.join(self.folder_path, name))
        )
        self.class_to_idx = self._build_class_index(class_names)

        for class_name in class_names:
            class_folder = os.path.join(self.folder_path, class_name)
            for filename in sorted(os.listdir(class_folder)):
                file_path = os.path.join(class_folder, filename)
                # Skip non-image files and directories that merely look like
                # images by name.
                if not filename.lower().endswith(self.VALID_EXTENSIONS):
                    continue
                if not os.path.isfile(file_path):
                    continue
                self.image_paths.append(file_path)
                self.labels.append(class_name)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label_tensor)`` for the sample at ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was supplied; the label is an integer tensor taken from
        ``class_to_idx``.
        """
        image_path = self.image_paths[idx]

        # Use a context manager so the underlying file handle is closed as
        # soon as the pixel data has been converted, not at garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)

        label_tensor = torch.tensor(self.class_to_idx[self.labels[idx]])

        return image, label_tensor

# TRAIN LOADER

In [5]:
# Training split: one sub-folder per class underneath this directory.
data_folder = "./Datasets/DOCS_V1/train/"

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalize each channel with the mean/std constants below.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

train_dataset = CustomDataset(data_folder, transform=transform)

# Fail fast on an empty dataset rather than running a training loop that
# never iterates.
if not len(train_dataset):
    raise ValueError("No images found in the dataset. Please check the 'train' subfolders.")

train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
# NOTE(review): this value is reassigned to 5 in the optimizer cell before
# the training loop runs.
num_epochs = 4

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# AdamW over every model parameter, with a small learning rate and weight decay.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=1e-5,
    weight_decay=0.01,
)
# Number of full passes over the training loader.
num_epochs = 5

# TRAIN

In [None]:
# Fine-tuning loop: one pass over train_loader per epoch, one optimizer
# update per batch, with the loss printed every 100 steps.
for epoch in range(num_epochs):
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = (tensor.to(device) for tensor in batch)

        logits = model(pixel_values=inputs, return_dict=True).logits
        # Flatten everything after the batch dimension so the loss receives
        # one row of scores per sample.
        batch_size = labels.size(0)
        loss = criterion(logits.view(batch_size, -1), labels.view(-1))

        loss.backward()
        optimizer.step()
        # Clear gradients after the update so the next batch starts from zero.
        optimizer.zero_grad()

        if not step % 100:
            print(f"Epoch {epoch}, Step {step}, Loss: {loss.item()}")

# TESTING

In [13]:
# Test split: same folder-per-class layout and preprocessing as training.
data_folder = "./Datasets/DOCS_V1/test/"
test_dataset = CustomDataset(data_folder, transform=transform)

# Fail fast on an empty dataset; the message names the split actually
# being loaded here (it previously pointed at the 'train' subfolders).
if len(test_dataset) == 0:
    raise ValueError("No images found in the dataset. Please check the 'test' subfolders.")

# Evaluation does not benefit from shuffling; a fixed order keeps runs
# reproducible and leaves the accuracy unchanged.
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [14]:
import torch.nn.functional as F

# Evaluation: run the model over the test loader with gradients disabled and
# count how many predicted indices equal the ground-truth labels.
model.eval()
correct_predictions = 0
total_predictions = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        logits = model(inputs, return_dict=True).logits

        # Collapse dimension 1 by keeping its maximum value.
        pooled_logits = logits.max(dim=1).values

        # Turn the pooled scores into probabilities, then pick the most
        # likely index for each sample.
        probabilities = torch.softmax(pooled_logits, dim=-1)
        predictions = probabilities.argmax(dim=-1)

        correct_predictions += torch.sum(predictions == labels).item()
        total_predictions += labels.size(0)

print("Total correct predictions:", correct_predictions)
print("Total predictions made:", total_predictions)

# Accuracy as a percentage of all evaluated samples.
accuracy = correct_predictions / total_predictions * 100
print("Accuracy:", accuracy)


Total correct predictions: 906
Total predictions made: 1187
Accuracy: 76.32687447346251


In [15]:
# Persist the fine-tuned model object (the whole module, not only its
# state_dict) so it can be restored later with torch.load.
torch.save(obj=model, f='./modelsave.pt')

In [49]:
# Notebook inspection: display the shape of the `labels` tensor currently
# bound in the session (the last batch assigned by one of the loops above).
labels.shape

torch.Size([4])