# ViT Small Pretrained on DINOv2 with registers 
## Data: Myxococcaceae vs non-Myxococcaceae
## Augmentation: TrivialAugmentWide to 60000 samples



### import requirements

In [1]:
import copy
import torchvision.models as models
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms, autoaugment
from sklearn.utils import resample
import os
import numpy as np
import matplotlib.pyplot as plt

## Preparing Data

### Loading Original Dataset

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'your device is {device}')

# Defining data transforms
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomResizedCrop((224, 224))
])

data_path = 'D:\Master Project\model\model-1\myxo-vs-nonmyxo-V2-9p'
dataset = ImageFolder(root=data_path, transform=data_transforms)

batch_size = 32
validation_split = 0.2

# Create indices for splitting
targets = torch.tensor(dataset.targets)
class_counts = targets.unique(return_counts=True)[1]
num_classes = len(class_counts)

# Calculating indices for each class to maintain the ratio
indices = torch.randperm(len(dataset)).tolist()
train_idx, val_idx = [], []

for class_index in range(num_classes):
    class_indices = [i for i in range(len(targets)) if targets[i] == class_index]
    split = int(len(class_indices) * validation_split)
    class_indices = torch.tensor(class_indices)[torch.randperm(len(class_indices))].tolist()
    
    val_idx += class_indices[:split]
    train_idx += class_indices[split:]

# Defining the subsets for training and validation
original_datasets = {
    'train': Subset(dataset, train_idx),
    'val': Subset(dataset, val_idx)
}
print('datasets have been created')

original_dataloaders = {x: DataLoader(dataset=original_datasets[x], batch_size=batch_size, num_workers=2,
                                      shuffle=True if x == 'train' else False, drop_last=True)
                        for x in ['train', 'val']}
print('dataloaders have been created')

class_names = dataset.classes
num_classes = len(class_names)
print(f'there are {num_classes} classes, and class names are {class_names}')

class_counts_dict = {x: len(original_datasets[x]) for x in ['train', 'val']}
print(f'Dataset sizes: {class_counts_dict}')

your device is cuda
datasets have been created
dataloaders have been created
there are 2 classes, and class names are ['Myxococcaceae', 'non-Myxococcaceae']
Dataset sizes: {'train': 4558, 'val': 1139}


### Counting Classes 

In [3]:
from collections import Counter

class_counts = Counter()

for phase in ['train', 'val']:
    for _, label in original_dataloaders[phase]:
        class_counts.update(label.tolist())

# show details
for label, count in class_counts.items():
    print(f'Class {label}: {count} instances')

Class 0: 4847 instances
Class 1: 817 instances


### Defining Augmentation Class 

In [4]:
class CustomAugmentedDataset(Dataset):
    def __init__(self, class_names, num_samples_per_class, root_dir=None, dataset=None, transform=None,
                 num_magnitude_bins=30):
        """
        Args:
            root_dir (string): Directory with all the images.
            num_samples_per_class (int): Desired number of samples per class after augmentation.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        if dataset:
            self.dataset = dataset
            self.situation = 'dataset'
        else:
            self.dataset = ImageFolder(root=root_dir)
            self.situation = 'root_dir'
        self.classes = class_names
        self.num_samples_per_class = num_samples_per_class
        self.transform = transform
        self.augment_transform = transforms.Compose([
            autoaugment.TrivialAugmentWide(num_magnitude_bins=num_magnitude_bins),
            transforms.ToTensor(),
        ])
        self.class_samples = self._balance_classes()

    def _balance_classes(self):
        from collections import defaultdict
        class_indices = defaultdict(list)
        if self.situation == 'root_dir':
            for idx, (_, class_id) in enumerate(self.dataset.samples):
                class_indices[class_id].append(idx)
        else:
            for idx, (_, class_id) in enumerate(self.dataset):
                class_indices[class_id].append(idx)

        # Reduce or oversample class indices to match num_samples_per_class
        balanced_indices = []
        for indices in class_indices.values():
            if len(indices) >= self.num_samples_per_class:
                balanced_indices.extend(indices[:self.num_samples_per_class])
            else:
                # Oversample if there are fewer samples than desired
                oversampled_indices = indices * (self.num_samples_per_class // len(indices)) + indices[
                                                                                               :self.num_samples_per_class % len(
                                                                                                   indices)]
                balanced_indices.extend(oversampled_indices)

        return balanced_indices

    def __len__(self):
        return len(self.class_samples)

    def __getitem__(self, idx):
        img, label = self.dataset[self.class_samples[idx]]
        if self.transform is not None:
            img = self.transform(img)
        else:
            img = self.augment_transform(img)
        return img, label

    def classes(self):
        return self.classes

### Creating Augmented Dataset

In [5]:
# Define Parameters
num_magnitude_bins = 100
num_samples_per_class = 30000

# Define any additional transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Creating datasets
datasets = {
    x: CustomAugmentedDataset(class_names=class_names, dataset=original_datasets[x], transform=transform,
                              num_magnitude_bins=num_magnitude_bins,
                              num_samples_per_class=num_samples_per_class if x == 'train' else 3000)
    for x in ['train', 'val']
}
print('Datasets created.')

# Creating dataloaders
batch_size = 32
dataloaders = {
    x: DataLoader(dataset=datasets[x], batch_size=batch_size, num_workers=2, shuffle=True if x == 'train' else False,
                  drop_last=True)
    for x in ['train', 'val']
}
print('Dataloaders created.')
print('-' * 50)

# Show Classes
class_names = datasets['train'].classes
print(f'there are {len(class_names)} classes, and class names are {class_names}')
print('-' * 50)

# Show datasets length 
class_counts_dict = {x: len(datasets[x]) for x in ['train', 'val']}
print(f'Dataset sizes: {class_counts_dict}')



Datasets created.
Dataloaders created.
--------------------------------------------------
there are 2 classes, and class names are ['Myxococcaceae', 'non-Myxococcaceae']
--------------------------------------------------
Dataset sizes: {'train': 60000, 'val': 6000}


### Counting Classes

In [None]:
from collections import Counter

class_counts = Counter()

for phase in ['train', 'val']:
    for _, label in dataloaders[phase]:
        class_counts.update(label.tolist())

# show details
for label, count in class_counts.items():
    print(f'Class {label}: {count} instances')

# Showing augmented data sample 

In [6]:
import matplotlib.pyplot as plt

idx = np.random.randint(0, 60000, size=1)

print('idx: ', int(idx))
print('idx type: ', type(idx))
image, label = datasets['train'][int(idx)]
print('image type: ', type(image))
print('label: ', label)
print('class name label: ', class_names[label])

# Convert torch tensor for plotting
image = image.permute(1, 2, 0)
plt.grid(False)
plt.axis('off')
plt.imshow(image)
plt.show()


  print('idx: ', int(idx))
  image, label = datasets['train'][int(idx)]


idx:  21637
idx type:  <class 'numpy.ndarray'>


TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

# Load ViT pretrained on DINOv2 with registers model

In [None]:
# DINOv2
dinov2_vits14_21M = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vitb14_86M = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
# dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
# dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# DINOv2 with registers
dinov2_vits14_reg_21M = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
dinov2_vitb14_reg_86M = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
# dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
# dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')

# Fine-Tune model classifier and trainable parameters 

In [None]:
# Define model
model = dinov2_vits14_reg_21M

# Define classifier for Binary Classification task
model.head = nn.Sequential(
    nn.Linear(384, len(class_names))
)
print(dinov2_vits14_reg_21M)

# Set about 30% of parameters trainable 
model_params = 0
for idx, param in enumerate(model.parameters()):
    param.requires_grad = False
    model_params += 1
    if idx == 125:
        break

## Defining Train function 

In [None]:
from datetime import datetime
from easydict import EasyDict


# train function 
def train_model(model, criterion, optimizer, dataloaders, datasets, epoch_num=25):
    acc_list = EasyDict({'train': [], 'val': []})
    loss_list = EasyDict({'train': [], 'val': []})

    # Copy the best model weights for loading at the End
    best_model_wts = copy.deepcopy(model.state_dict())
    best_accuracy = 0.0

    # Iterating over epochs
    for epoch in range(1, epoch_num + 1):
        print(f'Epoch {epoch}/{epoch_num}:')

        # Each epoch has two phase Train and Validation
        for phase in ['train', 'val']:
            s0 = datetime.now()
            if phase == 'train':
                model.train()
            else:
                model.eval()

            # For calculating Loss and Accuracy at the end of epoch
            running_loss = 0.0
            running_corrects = 0.0

            # Iterating over batches and data for training and validation
            for idx, batch in enumerate(dataloaders[phase], 0):
                inputs, labels = batch

                # Transfer data and labels to CUDA if is available
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Forward Pass
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)

                    assert outputs.shape[1] == 15, "Output size does not match number of classes"

                    loss = criterion(outputs, labels)

                    assert labels.min() >= 0 and labels.max() < 15, "Labels are out of range"

                    _, predictions = torch.max(outputs, 1)

                    # Back Propagation and updating weights
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(predictions == labels.data)

            # Calculating Accuracy and Loss per phase
            epoch_loss = running_loss / len(datasets[phase])
            epoch_accuracy = running_corrects / len(datasets[phase])

            # Show epoch details
            delta = datetime.now() - s0
            print(f'{phase.capitalize()} Accuracy: {epoch_accuracy:.4f} | Loss: {epoch_loss:.4f} | time: {delta}')

            # Copy the model weights if its better
            if phase == 'val' and epoch_accuracy > best_accuracy:
                best_accuracy = epoch_accuracy
                best_model_wts = copy.deepcopy(model.state_dict())
                print('Best model weights updated!')

            # Save Loss and accuracy
            acc_list[phase].append(epoch_accuracy)
            loss_list[phase].append(epoch_loss)
        print('-' * 50)

    print(f'Best Accuracy: {best_accuracy:.4f}')

    # Loading best model weights 
    model.load_state_dict(best_model_wts)
    return model, acc_list, loss_list

# Train ViT-s DINOv2 with registers 
---------------
## Hyperparameters:
### optimizer: Adam
### criterion: CrossEntropy
### Learning Rate: 0.001
### batch size: 32
### epoch: 50

In [None]:
# Defining hyperparameters
criterion = CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'your device is {device}', end='\n\n')
optimizer = Adam(model.parameters(), lr=0.001)
model = model.to(device)
print(model)
print('-' * 50)

# train model
model, acc_lists, loss_lists = train_model(model, criterion, optimizer, dataloaders, datasets, epoch_num=40)

## Plot Results 

In [None]:
plt.plot([a.cpu() for a in acc_lists.train], label='train')
plt.plot([a.cpu() for a in acc_lists.val], label='val')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy Percent')
plt.legend()
plt.show()

In [None]:
plt.plot([a for a in loss_lists.train], label='train loss')
plt.plot([a for a in loss_lists.val], label='val loss')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Percent')
plt.legend()
plt.show()


In [None]:
plt.plot([a.cpu() for a in acc_lists.train], label='train acc')
plt.plot([a.cpu() for a in acc_lists.val], label='val acc')
plt.plot([a for a in loss_lists.train], label='train loss')
plt.plot([a for a in loss_lists.val], label='val loss')
plt.title('result')
plt.xlabel('Epoch')
plt.ylabel('Accuracy Percent')
plt.legend()
plt.show()

## Save best model weights

In [None]:
torch.save(model, 'models/model_2.pth')

## Visualize model predictions

In [None]:
def visualize_model(model):
    model.eval()
    nrows, ncols = 4, 4
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 10))

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, predictions = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                img = inputs.cpu().data[j]
                img = img.numpy().transpose((1, 2, 0))
                # img = std * img + mean
                img = np.clip(img, 0, 1)
                axes[i][j].axis('off')
                axes[i][j].set_title(
                    f'predictions: {class_names[predictions[j]]}, label: {class_names[labels[j]]}'
                )
                axes[i][j].imshow(img)
                if j == ncols - 1:
                    break
            if i == nrows - 1:
                break
    plt.savefig('vis.jpg')


model = torch.load('models/model_2.pth')
visualize_model(model)

## Plot Confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix
import pandas as pd
import seaborn as sns


def plot_cm(model):
    y_true, y_pred = [], []
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            outputs = model(inputs)

            outputs = (torch.max(torch.exp(outputs), 1)[1]).data.cpu().numpy()
            y_pred.extend(outputs)

            labels = labels.data.cpu().numpy()
            y_true.extend(labels)

    cm = confusion_matrix(y_true, y_pred)
    df_cm = pd.DataFrame(
        cm / np.sum(cm, axis=1)[:, None],
        index=[i for i in class_names],
        columns=[i for i in class_names]
    )

    plt.figure(figsize=(15, 10))
    sns.heatmap(df_cm, annot=True, cbar=False)
    plt.show()


plot_cm(model)

# Fine-Tune model2 classifier and trainable parameters 

In [None]:
# Define model
model2 = dinov2_vits14_reg_21M

# Define classifier for Binary Classification task
model2.head = nn.Sequential(
    nn.Linear(384, 2)
)

# Set about 30% of parameters trainable 
model_params = 0
for idx, param in enumerate(model2.parameters()):
    param.requires_grad = False
    model_params += 1
    if idx == 125:
        break

# Train ViT-s DINOv2 with registers 
---------------
## Hyperparameters:
### optimizer: Adam
### criterion: CrossEntropy
### Learning Rate: 0.0003
### batch size: 32
### epoch: 50

In [None]:
# Defining Hyperparameters 
criterion = CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'your device is {device}', end='\n\n')
optimizer = Adam(model2.parameters(), lr=0.0003)
model2 = model2.to(device)
print(model2)
print('-' * 50)

# train model
model2, acc_lists2, loss_lists2 = train_model(model2, criterion, optimizer, dataloaders, datasets, epoch_num=50)

# Plot results

In [None]:
plt.plot([a.cpu() for a in acc_lists2.train], label='train')
plt.plot([a.cpu() for a in acc_lists2.val], label='val')
plt.title('Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy Percent')
plt.legend()
plt.show()

In [None]:
plt.plot([a for a in loss_lists2.train], label='train loss')
plt.plot([a for a in loss_lists2.val], label='val loss')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Percent')
plt.legend()
plt.show()


In [None]:
plt.plot([a.cpu() for a in acc_lists2.train], label='train acc')
plt.plot([a.cpu() for a in acc_lists2.val], label='val acc')
plt.plot([a for a in loss_lists2.train], label='train loss')
plt.plot([a for a in loss_lists2.val], label='val loss')
plt.title('result')
plt.xlabel('Epoch')
plt.ylabel('Accuracy Percent')
plt.legend()
plt.show()

# Visualize model predictions

In [None]:
def visualize_model(model):
    model.eval()
    nrows, ncols = 4, 4
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 10))

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, predictions = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                img = inputs.cpu().data[j]
                img = img.numpy().transpose((1, 2, 0))
                # img = std * img + mean
                img = np.clip(img, 0, 1)
                axes[i][j].axis('off')
                axes[i][j].set_title(
                    f'predictions: {class_names[predictions[j]]}, label: {class_names[labels[j]]}'
                )
                axes[i][j].imshow(img)
                if j == ncols - 1:
                    break
            if i == nrows - 1:
                break
    plt.savefig('vis.jpg')


visualize_model(model2)

## Save best model weights

In [None]:
torch.save(model, 'models/model_3.pth')